Commit 856f87a1 authored by Wil Kong

Add unittest for multi-tensor apply FusedAdam.

parent 3f86316e
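For context, a minimal sketch of the mixed-precision flow the new test exercises. This is an illustration, not apex's documented API surface: it assumes FusedAdam accepts the `use_mt` constructor flag and the `step(grads=..., output_params=...)` signature that the test in the diff below relies on, and the hyperparameters mirror those used in the test.

    import torch
    import apex

    # fp32 master weights plus fp16 copies, as in test_multi_tensor below.
    master = [torch.nn.Parameter(torch.rand(4096, 1024, device='cuda'))]
    fp16_copies = [torch.nn.Parameter(p.detach().clone().half()) for p in master]

    # `use_mt` selects the multi-tensor apply path (assumption: taken from
    # the tst_adam_option in the diff, not from apex documentation).
    opt = apex.optimizers.FusedAdam(master, lr=5e-4, betas=(0.9, 0.999),
                                    eps=1e-08, weight_decay=0, use_mt=True)

    # Dummy fp16 gradients stand in for the test's gen_mixed_grad() helper.
    half_grads = [torch.rand_like(p, dtype=torch.half) for p in master]

    # One fused step: the fp32 master weights are updated and the new values
    # are written back into the fp16 copies in the same pass.
    opt.step(grads=half_grads, output_params=fp16_copies)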
@@ -15,15 +15,18 @@ class TestFusedAdam(unittest.TestCase):
     def tearDown(self):
         pass
 
-    def gen_param_optim(self, tensors, adam_option):
+    def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
         ref_param = []
         tst_param = []
         for tensor in tensors:
             ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))
 
-        ref_optim = torch.optim.Adam(ref_param, **adam_option)
-        tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)
+        ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
+        if tst_adam_option:
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **tst_adam_option)
+        else:
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **ref_adam_option)
 
         return (ref_param, tst_param, ref_optim, tst_optim)
@@ -42,8 +45,8 @@ class TestFusedAdam(unittest.TestCase):
     def get_max_diff(self, ref_param, tst_param):
         max_abs_diff = max_rel_diff = 0
         for p_ref, p_tst in zip(ref_param, tst_param):
-            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
-            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
+            max_abs_diff_p = (p_ref - p_tst.type(p_ref.type())).abs().max().item()
+            max_rel_diff_p = ((p_ref - p_tst.type(p_ref.type())) / p_ref).abs().max().item()
             if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
             if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p
@@ -173,6 +176,34 @@ class TestFusedAdam(unittest.TestCase):
         self.assertLessEqual(max_abs_diff, self.max_abs_diff)
         self.assertLessEqual(max_rel_diff, self.max_rel_diff)
 
+    def test_multi_tensor(self):
+        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
+        ref_adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
+            'weight_decay':0, 'amsgrad':False}
+        tst_adam_option = dict(ref_adam_option, **{'use_mt':True})
+
+        tensors = []
+        fp16_params = []
+        for size in sizes:
+            tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
+            fp16_params.append(torch.nn.Parameter(tensors[-1].clone().half()))
+        ref_param, tst_param, ref_optim, tst_optim = \
+            self.gen_param_optim(tensors, ref_adam_option, tst_adam_option)
+
+        for i in range(self.iters):
+            half_grads = self.gen_mixed_grad(ref_param, tst_param)
+            ref_optim.step()
+            tst_optim.step(grads=half_grads, output_params=fp16_params)
+
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+
+            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, \
+                fp16_params)
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+
 
 if __name__ == '__main__':
     script_path = os.path.dirname(os.path.realpath(__file__))