"lightx2v/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "4e735704d3ce4169578743932681ad8727e1362b"
Commit ccb6947d authored by thomwolf's avatar thomwolf
Browse files

optimization tests

parent e4f9dca0
......@@ -96,8 +96,10 @@ def train(args, train_dataset, model, tokenizer):
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
......@@ -129,7 +131,7 @@ def train(args, train_dataset, model, tokenizer):
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics
if args.local_rank == -1: # Only evaluate when single GPU otherwise metrics may not average well
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer)
for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
......@@ -148,8 +150,10 @@ def train(args, train_dataset, model, tokenizer):
logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
return global_step, tr_loss / global_step
......@@ -164,11 +168,10 @@ def evaluate(args, model, tokenizer, prefix=""):
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
""" Evaluate the model """
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size * args.n_gpu
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
......@@ -177,7 +180,7 @@ def evaluate(args, model, tokenizer, prefix=""):
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
......@@ -287,6 +290,8 @@ def main():
help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--evaluate_during_training", action='store_true',
help="Rul evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
......@@ -409,6 +414,8 @@ def main():
elif args.n_gpu > 1:
model = torch.nn.DataParallel(model)
logger.info("Training/evaluation parameters %s", args)
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
......@@ -438,15 +445,15 @@ def main():
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
checkpoints = [args.output_dir + './' + WEIGHTS_NAME]
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
results = {}
for checkpoint in checkpoints:
global_step = int(checkpoint.split('-')[-1])
global_step = checkpoint.split('-')[-1]
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=global_step)
......
......@@ -45,9 +45,18 @@ class ExamplesTests(unittest.TestCase):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
testargs = ["run_glue.py", "--data_dir=./examples/tests_samples/MRPC/",
"--task_name=mrpc", "--do_train", "--do_eval", "--output_dir=./examples/tests_samples/temp_dir",
"--train_batch_size=4", "--eval_batch_size=2", "--num_train_epochs=2.0", "--overwrite_output_dir"]
testargs = ["run_glue.py",
"--data_dir=./examples/tests_samples/MRPC/",
"--task_name=mrpc",
"--do_train",
"--do_eval",
"--output_dir=./examples/tests_samples/temp_dir",
"--per_gpu_train_batch_size=2",
"--per_gpu_eval_batch_size=1",
"--learning_rate=1e-4",
"--max_steps=10",
"--warmup_steps=2",
"--overwrite_output_dir"]
model_name = "--model_name=bert-base-uncased"
with patch.object(sys, 'argv', testargs + [model_name]):
result = run_glue.main()
......
......@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class ConstantLRSchedule(LambdaLR):
def __init__(self, optimizer, last_epoch=-1):
super(ConstantLRSchedule, self).__init__(optimizer, lambda x: x, last_epoch=last_epoch)
super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
class WarmupCosineSchedule(LambdaLR):
"""
......@@ -42,10 +42,10 @@ class WarmupCosineSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
return float(step) / float(max(1.0, warmup_steps))
else:
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * cycles * 2 * progress))
progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps)) # progress after warmup
return 0.5 * (1. + math.cos(math.pi * float(cycles) * 2.0 * progress))
super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
......@@ -59,11 +59,12 @@ class WarmupCosineWithHardRestartsSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
return float(step) / float(max(1, warmup_steps))
else:
progress = (step - warmup_steps) / max(1, t_total - warmup_steps) # progress after warmup
ret = 0.5 * (1. + math.cos(math.pi * ((cycles * progress) % 1)))
return ret
progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps)) # progress after warmup
if progress >= 1.0:
return 0.0
return 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0)))
super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
......@@ -77,7 +78,7 @@ class WarmupConstantSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
return float(step) / float(max(1.0, warmup_steps))
return 1.
super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
......@@ -92,8 +93,8 @@ class WarmupLinearSchedule(LambdaLR):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
return (t_total - step) / max(1, t_total - warmup_steps)
return float(step) / float(max(1, warmup_steps))
return float(t_total - step) / float(max(1.0, t_total - warmup_steps))
super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
......
......@@ -26,6 +26,13 @@ from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSched
import numpy as np
def unwrap_schedule(scheduler, num_steps=10):
lrs = []
for _ in range(num_steps):
scheduler.step()
lrs.append(scheduler.get_lr())
return lrs
class OptimizationTest(unittest.TestCase):
def assertListAlmostEqual(self, list1, list2, tol):
......@@ -38,9 +45,7 @@ class OptimizationTest(unittest.TestCase):
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = AdamW(params=[w], lr=2e-1,
weight_decay=0.0,
max_grad_norm=-1)
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
for _ in range(100):
loss = criterion(w, target)
loss.backward()
......@@ -51,29 +56,49 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase):
def test_sched_init(self):
m = torch.nn.Linear(50, 50)
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
# shouldn't fail
class WarmupCosineWithRestartsTest(unittest.TestCase):
def test_it(self):
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
x = np.arange(0, 1000)
y = [m.get_lr(xe) for xe in x]
y = np.asarray(y)
expected_zeros = y[[0, 200, 400, 600, 800]]
print(expected_zeros)
expected_ones = y[[50, 250, 450, 650, 850]]
print(expected_ones)
self.assertTrue(np.allclose(expected_ones, 1))
self.assertTrue(np.allclose(expected_zeros, 0))
optimizer = AdamW(m.parameters(), lr=10.)
num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol):
self.assertEqual(len(list1), len(list2))
for a, b in zip(list1, list2):
self.assertAlmostEqual(a, b, delta=tol)
def test_constant_scheduler(self):
scheduler = ConstantLRSchedule(self.optimizer)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [10.] * self.num_steps
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_constant_scheduler(self):
scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_linear_scheduler(self):
scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
def test_warmup_cosine_scheduler(self):
scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
def test_warmup_cosine_hard_restart_scheduler(self):
scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment