Unverified Commit 238521b0 authored by Kai Fricke's avatar Kai Fricke Committed by GitHub
Browse files

Replace NotebookProgressReporter by ProgressReporter in Ray Tune run (#12357)

* Replace NotebookProgressReporter by ProgressReporter in Ray Tune run

* Move to local import
parent 332a2458
......@@ -42,7 +42,7 @@ if _has_comet:
_has_comet = False
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from .trainer_callback import TrainerCallback # noqa: E402
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
......@@ -153,6 +153,14 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
import ray
def _objective(trial, local_trainer, checkpoint_dir=None):
try:
from transformers.utils.notebook import NotebookProgressCallback
if local_trainer.pop_callback(NotebookProgressCallback):
local_trainer.add_callback(ProgressCallback)
except ModuleNotFoundError:
pass
checkpoint = None
if checkpoint_dir:
for subdir in os.listdir(checkpoint_dir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment