Unverified commit def8fb85, authored by mcarilli, committed by GitHub

Merge pull request #100 from FDecaYed/deyuf/optimizer_unittests

add unit tests for optimizers/fp16_optimizer
parents 0b848f0d afc8d1b2
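For context, the wrapping pattern these tests exercise is sketched below. This is a minimal outline inferred from the test bodies themselves, not an excerpt from apex's documentation: FusedAdam is wrapped in apex.optimizers.FP16_Optimizer, and backward() is called on the wrapper rather than on the loss tensor.

model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = apex.optimizers.FusedAdam(model.parameters())
optimizer = apex.optimizers.FP16_Optimizer(optimizer)  # keeps fp32 master weights for the fp16 model
loss = model(x).sum()
optimizer.backward(loss)  # scales the loss (if a loss scale is set) and produces master gradients
optimizer.step()          # updates master weights, then copies them back into the fp16 parameters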
import os
import unittest

import torch

import apex

class TestFP16Optimizer(unittest.TestCase):

    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        # Tolerances: parameters must agree within 1e-3 absolute and
        # 100% relative difference after `iters` optimizer steps.
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(13337)

        N, D_in, D_out = 64, 1024, 16
        self.N = N
        self.D_in = D_in
        self.D_out = D_out
        self.x = torch.randn((N, D_in), dtype=torch.float16, device='cuda')
        self.ref_model = torch.nn.Linear(D_in, D_out).cuda().half()
        self.tst_model = torch.nn.Linear(D_in, D_out).cuda().half()
        # Start both models from identical weights so any divergence is
        # attributable to the optimizer implementations.
        for p, q in zip(self.tst_model.parameters(), self.ref_model.parameters()):
            p.data.copy_(q.data)

    def get_max_diff(self, ref_param, tst_param):
        # Largest elementwise absolute and relative differences across all
        # parameter tensors (relative differences are taken w.r.t. the
        # reference parameters).
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
            if max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p
        return max_abs_diff, max_rel_diff

    def test_fp16_optimizer(self):
        ref_optim = torch.optim.Adam(self.ref_model.parameters())
        ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, verbose=False)
        tst_optim = apex.optimizers.FusedAdam(self.tst_model.parameters())
        tst_optim = apex.optimizers.FP16_Optimizer(tst_optim)

        for i in range(self.iters):
            ref_loss = self.ref_model(self.x).sum()
            ref_optim.backward(ref_loss)
            ref_optim.step()
            tst_loss = self.tst_model(self.x).sum()
            tst_optim.backward(tst_loss)
            tst_optim.step()

        max_abs_diff, max_rel_diff = self.get_max_diff(self.ref_model.parameters(),
                                                       self.tst_model.parameters())
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)
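
    # Both wrappers deliberately expose backward() on the optimizer instead of
    # relying on loss.backward(): the wrapper scales the loss when a loss scale
    # is configured and copies the resulting fp16 gradients into the fp32
    # master gradients consumed by step().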

    def test_loss_scaling(self):
        ref_optim = torch.optim.Adam(self.ref_model.parameters())
        ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, static_loss_scale=128.0, verbose=False)
        tst_optim = apex.optimizers.FusedAdam(self.tst_model.parameters())
        tst_optim = apex.optimizers.FP16_Optimizer(tst_optim, static_loss_scale=128.0)

        for i in range(self.iters):
            ref_loss = self.ref_model(self.x).sum()
            ref_optim.backward(ref_loss)
            ref_optim.step()
            tst_loss = self.tst_model(self.x).sum()
            tst_optim.backward(tst_loss)
            tst_optim.step()

        max_abs_diff, max_rel_diff = self.get_max_diff(self.ref_model.parameters(),
                                                       self.tst_model.parameters())
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)
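
    # The static scale of 128.0 must match on both sides so the reference and
    # fused paths see identically scaled losses. A dynamic-scaling variant is
    # not covered here; per the apex.fp16_utils documentation it would be
    # constructed along the lines of
    #   apex.fp16_utils.FP16_Optimizer(optim, dynamic_loss_scale=True)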

    def test_parameter_groups(self):
        ref_groups = [{'params': [self.ref_model.weight]},
                      {'params': [self.ref_model.bias]}]
        ref_optim = torch.optim.Adam(ref_groups)
        ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, verbose=False)
        tst_groups = [{'params': [self.tst_model.weight]},
                      {'params': [self.tst_model.bias]}]
        tst_optim = apex.optimizers.FusedAdam(tst_groups)
        tst_optim = apex.optimizers.FP16_Optimizer(tst_optim)

        for i in range(self.iters):
            ref_loss = self.ref_model(self.x).sum()
            ref_optim.backward(ref_loss)
            ref_optim.step()
            tst_loss = self.tst_model(self.x).sum()
            tst_optim.backward(tst_loss)
            tst_optim.step()

        max_abs_diff, max_rel_diff = self.get_max_diff(self.ref_model.parameters(),
                                                       self.tst_model.parameters())
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)
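
    # Splitting weight and bias into separate groups checks that both wrappers
    # preserve parameter-group structure rather than flattening everything
    # into a single group.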

    def test_grad_clip(self):
        ref_optim = torch.optim.Adam(self.ref_model.parameters())
        ref_optim = apex.fp16_utils.FP16_Optimizer(ref_optim, verbose=False)
        tst_optim = apex.optimizers.FusedAdam(self.tst_model.parameters(), max_grad_norm=0.01)
        tst_optim = apex.optimizers.FP16_Optimizer(tst_optim)

        for i in range(self.iters):
            ref_loss = self.ref_model(self.x).sum()
            ref_optim.backward(ref_loss)
            ref_optim.clip_master_grads(0.01)
            ref_optim.step()
            tst_loss = self.tst_model(self.x).sum()
            tst_optim.backward(tst_loss)
            tst_optim.step()

        max_abs_diff, max_rel_diff = self.get_max_diff(self.ref_model.parameters(),
                                                       self.tst_model.parameters())
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)
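
    # The clipping happens in different places on the two paths: the reference
    # optimizer clips explicitly via clip_master_grads(0.01), while FusedAdam
    # is asked to apply the same global-norm clip internally through its
    # max_grad_norm=0.01 constructor argument.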

    @unittest.skip('Does not support grads that are None')
    def test_grad_None(self):
        self.fail()

    @unittest.skip('Does not support the same weight decay behavior as PyTorch')
    def test_weight_decay(self):
        self.fail()

    @unittest.skip('Does not support empty parameter groups')
    def test_group_empty(self):
        self.fail()

if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))  # kept from the original; currently unused
    unittest.main()
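Because the file ends with unittest.main(), it can be run directly with python on a CUDA-capable machine (the exact file name is whatever this test module is saved as in the repository), or picked up by the usual python -m unittest discovery.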