Commit 3c7e676f authored by erenup
Browse files

Add test-related code: test the model with the best dev accuracy while the model is training

parent fc741325
...@@ -169,7 +169,7 @@ def train(args, train_dataset, model, tokenizer): ...@@ -169,7 +169,7 @@ def train(args, train_dataset, model, tokenizer):
results = evaluate(args, model, tokenizer) results = evaluate(args, model, tokenizer)
for key, value in results.items(): for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step) tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
if results["eval_loss"] < best_dev_loss: if results["eval_acc"] < best_dev_acc:
best_dev_acc = results["eval_acc"] best_dev_acc = results["eval_acc"]
best_dev_loss = results["eval_loss"] best_dev_loss = results["eval_loss"]
best_steps = global_step best_steps = global_step
...@@ -469,12 +469,12 @@ def main(): ...@@ -469,12 +469,12 @@ def main():
model.to(args.device) model.to(args.device)
logger.info("Training/evaluation parameters %s", args) logger.info("Training/evaluation parameters %s", args)
best_steps = 0
# Training # Training
if args.do_train: if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss, _ = train(args, train_dataset, model, tokenizer) global_step, tr_loss, best_steps = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
...@@ -522,7 +522,7 @@ def main(): ...@@ -522,7 +522,7 @@ def main():
if not args.do_train: if not args.do_train:
args.output_dir = args.model_name_or_path args.output_dir = args.model_name_or_path
checkpoints = [args.output_dir] checkpoints = [args.output_dir]
if args.eval_all_checkpoints: if args.eval_all_checkpoints: #can not use this to do test!! just for different paras
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
...@@ -533,7 +533,8 @@ def main(): ...@@ -533,7 +533,8 @@ def main():
result = evaluate(args, model, tokenizer, prefix=global_step, test=True) result = evaluate(args, model, tokenizer, prefix=global_step, test=True)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
if best_steps:
logger.info("best steps of eval acc is the following checkpoints: %s", best_steps)
return results return results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment