Unverified Commit c849a61e authored by Nicolas Brousse's avatar Nicolas Brousse Committed by GitHub
Browse files

Fix MLflowCallback and add support for MLFLOW_EXPERIMENT_NAME (#17091)

* Fix use of mlflow.active_run() and add proper support for MLFLOW_EXPERIMENT_NAME

* Fix code style (make style)
parent 99289c08
......@@ -781,17 +781,26 @@ class MLflowCallback(TrainerCallback):
Environment:
HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
Whether to use MLflow .log_artifact() facility to log artifacts.
This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy
whatever is in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it
without a remote storage will just copy the files to your artifact location.
Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
[`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
storage will just copy the files to your artifact location.
MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
Name of the MLflow experiment under which to launch the run. Defaults to `None`, which will
point to the "Default" experiment in MLflow. Otherwise, it is the case-sensitive name of the experiment
to be activated. If an experiment with this name does not exist, a new experiment with this name is
created.
"""
log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
if log_artifacts in {"TRUE", "1"}:
self._log_artifacts = True
experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
logger.debug(f"MLFLOW experiment_name={experiment_name}, run_name={args.run_name}")
if state.is_world_process_zero:
if self._ml_flow.active_run is None:
if self._ml_flow.active_run() is None:
if experiment_name:
# Using set_experiment() ensures that the experiment is created if it does not already exist
self._ml_flow.set_experiment(experiment_name)
self._ml_flow.start_run(run_name=args.run_name)
combined_dict = args.to_dict()
if hasattr(model, "config") and model.config is not None:
......@@ -844,7 +853,7 @@ class MLflowCallback(TrainerCallback):
def __del__(self):
    """End the MLflow run that is still active when this callback is destroyed.

    If the previous run is not terminated correctly, the fluent API will not
    let you start a new run before the previous one is killed, so any run
    that is still open is closed here.
    """
    # A finalizer can run on a partially constructed instance, so look the
    # MLflow handle up defensively instead of assuming `_ml_flow` is set.
    ml_flow = getattr(self, "_ml_flow", None)
    active_run = getattr(ml_flow, "active_run", None)
    # `active_run` must be *called*: comparing the function object itself to
    # None is always true and would reach `end_run()` even with no run open.
    if callable(active_run) and active_run() is not None:
        ml_flow.end_run()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment