Unverified Commit 41f3133a authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Extract metric_key_prefix during NotebookProgressCallback.on_evaluate (#11347)

* Pass metric_key_prefix as kwarg to on_evaluate

* Replace eval_loss with metric_key_prefix_loss

* Default to "eval" if metric_key_prefix not in kwargs

* Add kwargs to CallbackHandler.on_evaluate signature

* Revert "Add kwargs to CallbackHandler.on_evaluate signature"

This reverts commit 8d4c85ed512f558f7579d36771e907b3379947b7.

* Revert "Pass metric_key_prefix as kwarg to on_evaluate"

This reverts commit 7766bfe2718601230ae593d37b1317bd53cfc075.

* Extract metric_key_prefix from metrics
parent dabeb152
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import collections import collections
import re
import time import time
from typing import Optional from typing import Optional
...@@ -308,7 +309,7 @@ class NotebookProgressCallback(TrainerCallback): ...@@ -308,7 +309,7 @@ class NotebookProgressCallback(TrainerCallback):
def on_evaluate(self, args, state, control, metrics=None, **kwargs): def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if self.training_tracker is not None: if self.training_tracker is not None:
values = {"Training Loss": "No log"} values = {"Training Loss": "No log", "Validation Loss": "No log"}
for log in reversed(state.log_history): for log in reversed(state.log_history):
if "loss" in log: if "loss" in log:
values["Training Loss"] = log["loss"] values["Training Loss"] = log["loss"]
...@@ -318,13 +319,16 @@ class NotebookProgressCallback(TrainerCallback): ...@@ -318,13 +319,16 @@ class NotebookProgressCallback(TrainerCallback):
values["Epoch"] = int(state.epoch) values["Epoch"] = int(state.epoch)
else: else:
values["Step"] = state.global_step values["Step"] = state.global_step
values["Validation Loss"] = metrics["eval_loss"] metric_key_prefix = "eval"
for k in metrics:
if k.endswith("_loss"):
metric_key_prefix = re.sub(r"\_loss$", "", k)
_ = metrics.pop("total_flos", None) _ = metrics.pop("total_flos", None)
_ = metrics.pop("epoch", None) _ = metrics.pop("epoch", None)
_ = metrics.pop("eval_runtime", None) _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
_ = metrics.pop("eval_samples_per_second", None) _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
for k, v in metrics.items(): for k, v in metrics.items():
if k == "eval_loss": if k == f"{metric_key_prefix}_loss":
values["Validation Loss"] = v values["Validation Loss"] = v
else: else:
splits = k.split("_") splits = k.split("_")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment