Unverified Commit eb186bc1 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Small fixes to HP search (#7839)

parent d8ca57d2
...@@ -520,7 +520,7 @@ class Trainer: ...@@ -520,7 +520,7 @@ class Trainer:
): ):
if self.hp_search_backend is None or trial is None: if self.hp_search_backend is None or trial is None:
return return
self.objective = self.compute_objective(metrics) self.objective = self.compute_objective(metrics.copy())
if self.hp_search_backend == HPSearchBackend.OPTUNA: if self.hp_search_backend == HPSearchBackend.OPTUNA:
trial.report(self.objective, epoch) trial.report(self.objective, epoch)
if trial.should_prune(): if trial.should_prune():
......
...@@ -112,6 +112,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float: ...@@ -112,6 +112,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
""" """
loss = metrics.pop("eval_loss", None) loss = metrics.pop("eval_loss", None)
_ = metrics.pop("epoch", None) _ = metrics.pop("epoch", None)
_ = metrics.pop("total_flos", None)
return loss if len(metrics) == 0 else sum(metrics.values()) return loss if len(metrics) == 0 else sum(metrics.values())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment