Unverified Commit 54e9ce78 authored by Kevin Canwen Xu's avatar Kevin Canwen Xu Committed by GitHub
Browse files

Fix PABEE division by zero error (#5233)

* Fix PABEE division by zero error

* patience=0 by default
parent 9022ef02
...@@ -254,12 +254,15 @@ def train(args, train_dataset, model, tokenizer): ...@@ -254,12 +254,15 @@ def train(args, train_dataset, model, tokenizer):
return global_step, tr_loss / global_step return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""): def evaluate(args, model, tokenizer, prefix="", patience=0):
# PABEE STATS
if args.model_type == "albert": if args.model_type == "albert":
model.albert.set_regression_threshold(args.regression_threshold)
model.albert.set_patience(patience)
model.albert.reset_stats() model.albert.reset_stats()
elif args.model_type == "bert": elif args.model_type == "bert":
model.bert.set_regression_threshold(args.regression_threshold)
model.bert.set_patience(patience)
model.bert.reset_stats() model.bert.reset_stats()
else: else:
raise NotImplementedError() raise NotImplementedError()
...@@ -331,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -331,7 +334,7 @@ def evaluate(args, model, tokenizer, prefix=""):
print(" %s = %s" % (key, str(result[key]))) print(" %s = %s" % (key, str(result[key])))
writer.write("%s = %s\n" % (key, str(result[key]))) writer.write("%s = %s\n" % (key, str(result[key])))
if args.eval_all_checkpoints: if args.eval_all_checkpoints and patience != 0:
if args.model_type == "albert": if args.model_type == "albert":
model.albert.log_stats() model.albert.log_stats()
elif args.model_type == "bert": elif args.model_type == "bert":
...@@ -690,15 +693,7 @@ def main(): ...@@ -690,15 +693,7 @@ def main():
print(f"Evaluation for checkpoint {prefix}") print(f"Evaluation for checkpoint {prefix}")
for patience in patience_list: for patience in patience_list:
if args.model_type == "albert": result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
model.albert.set_regression_threshold(args.regression_threshold)
model.albert.set_patience(patience)
elif args.model_type == "bert":
model.bert.set_regression_threshold(args.regression_threshold)
model.bert.set_patience(patience)
else:
raise NotImplementedError()
result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
return results return results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment