Unverified Commit 54e9ce78 authored by Kevin Canwen Xu's avatar Kevin Canwen Xu Committed by GitHub
Browse files

Fix PABEE division by zero error (#5233)

* Fix PABEE division by zero error

* patience=0 by default
parent 9022ef02
......@@ -254,12 +254,15 @@ def train(args, train_dataset, model, tokenizer):
return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""):
def evaluate(args, model, tokenizer, prefix="", patience=0):
# PABEE STATS
if args.model_type == "albert":
model.albert.set_regression_threshold(args.regression_threshold)
model.albert.set_patience(patience)
model.albert.reset_stats()
elif args.model_type == "bert":
model.bert.set_regression_threshold(args.regression_threshold)
model.bert.set_patience(patience)
model.bert.reset_stats()
else:
raise NotImplementedError()
......@@ -331,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
print(" %s = %s" % (key, str(result[key])))
writer.write("%s = %s\n" % (key, str(result[key])))
if args.eval_all_checkpoints:
if args.eval_all_checkpoints and patience != 0:
if args.model_type == "albert":
model.albert.log_stats()
elif args.model_type == "bert":
......@@ -690,15 +693,7 @@ def main():
print(f"Evaluation for checkpoint {prefix}")
for patience in patience_list:
if args.model_type == "albert":
model.albert.set_regression_threshold(args.regression_threshold)
model.albert.set_patience(patience)
elif args.model_type == "bert":
model.bert.set_regression_threshold(args.regression_threshold)
model.bert.set_patience(patience)
else:
raise NotImplementedError()
result = evaluate(args, model, tokenizer, prefix=prefix)
result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
results.update(result)
return results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment