"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2cc8cf6ce7ae0416561acbb639df4bbc5f409b6f"
Unverified commit 6287c929, authored by Stas Bekman and committed by GitHub
Browse files

[lm examples] fix overflow in perplexity calc (#11855)

* fix overflow in perplexity calc

* use inf (fall back to `float("inf")` when `math.exp` raises `OverflowError`)

* fix
parent 7630c11f
...@@ -440,7 +440,10 @@ def main(): ...@@ -440,7 +440,10 @@ def main():
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"]) try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
......
...@@ -442,7 +442,10 @@ def main(): ...@@ -442,7 +442,10 @@ def main():
losses = torch.cat(losses) losses = torch.cat(losses)
losses = losses[: len(eval_dataset)] losses = losses[: len(eval_dataset)]
perplexity = math.exp(torch.mean(losses)) try:
perplexity = math.exp(torch.mean(losses))
except OverflowError:
perplexity = float("inf")
logger.info(f"epoch {epoch}: perplexity: {perplexity}") logger.info(f"epoch {epoch}: perplexity: {perplexity}")
......
...@@ -469,7 +469,10 @@ def main(): ...@@ -469,7 +469,10 @@ def main():
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"]) try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
......
...@@ -486,7 +486,10 @@ def main(): ...@@ -486,7 +486,10 @@ def main():
losses = torch.cat(losses) losses = torch.cat(losses)
losses = losses[: len(eval_dataset)] losses = losses[: len(eval_dataset)]
perplexity = math.exp(torch.mean(losses)) try:
perplexity = math.exp(torch.mean(losses))
except OverflowError:
perplexity = float("inf")
logger.info(f"epoch {epoch}: perplexity: {perplexity}") logger.info(f"epoch {epoch}: perplexity: {perplexity}")
......
...@@ -445,7 +445,10 @@ def main(): ...@@ -445,7 +445,10 @@ def main():
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
perplexity = math.exp(metrics["eval_loss"]) try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment