Unverified commit 766d4bf7 authored by Nicolas Brousse, committed by GitHub

Fix MLflowCallback end_run() and add support for tags and nested runs (#17130)

* ensure mlflow.end_run() is executed at end of training when mlflow.start_run() was executed by the callback

* add debug msg

* add support for MLFLOW_TAGS, MLFLOW_RUN_ID, and MLFLOW_NESTED_RUN

* update to support python 3.6+

* Validate env variables using ENV_VARS_TRUE_VALUES

* Empty-Commit
parent 2fbb2379
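For reference, here is a minimal sketch of how the environment variables described in this commit might be set before training starts. The concrete values are placeholders and not part of the commit itself.

```python
import json
import os

# Illustrative values only; MLflowCallback.setup() reads these variables when
# training begins. Truthy flags are validated against ENV_VARS_TRUE_VALUES.
os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"  # pre-existing option
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"             # pre-existing option
os.environ["MLFLOW_TAGS"] = json.dumps(                 # new: tags as a JSON string dump
    {"release.candidate": "RC1", "release.version": "2.2.0"}
)
os.environ["MLFLOW_NESTED_RUN"] = "1"                   # new: nest the run under the active one
# os.environ["MLFLOW_RUN_ID"] = "..."                   # new: reattach to an existing run
```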
@@ -16,6 +16,7 @@ Integrations with other Python libraries.
 """
 import functools
 import importlib.util
+import json
 import numbers
 import os
 import sys
@@ -772,6 +773,7 @@ class MLflowCallback(TrainerCallback):
         self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

         self._initialized = False
+        self._auto_end_run = False
         self._log_artifacts = False
         self._ml_flow = mlflow
@@ -790,18 +792,32 @@ class MLflowCallback(TrainerCallback):
                 point to the "Default" experiment in MLflow. Otherwise, it is a case sensitive name of the experiment
                 to be activated. If an experiment with this name does not exist, a new experiment with this name is
                 created.
+            MLFLOW_TAGS (`str`, *optional*):
+                A string dump of a dictionary of key/value pairs to be added to the MLflow run as tags. Example:
+                os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'
+            MLFLOW_NESTED_RUN (`str`, *optional*):
+                Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current
+                run.
+            MLFLOW_RUN_ID (`str`, *optional*):
+                Allows reattaching to an existing run, which can be useful when resuming training from a checkpoint.
+                When the MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the
+                specified run ID and other parameters are ignored.
         """
-        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
-        if log_artifacts in {"TRUE", "1"}:
-            self._log_artifacts = True
-        experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
-        logger.debug(f"MLFLOW experiment_name={experiment_name}, run_name={args.run_name}")
+        self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
+        self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
+        self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
+        self._run_id = os.getenv("MLFLOW_RUN_ID", None)
+        logger.debug(
+            f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}, tags={self._nested_run}"
+        )
         if state.is_world_process_zero:
-            if self._ml_flow.active_run() is None:
-                if experiment_name:
+            if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
+                if self._experiment_name:
                     # Use of set_experiment() ensure that Experiment is created if not exists
-                    self._ml_flow.set_experiment(experiment_name)
-                self._ml_flow.start_run(run_name=args.run_name)
+                    self._ml_flow.set_experiment(self._experiment_name)
+                self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run)
+                logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}")
+                self._auto_end_run = True
             combined_dict = args.to_dict()
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
@@ -821,6 +837,10 @@ class MLflowCallback(TrainerCallback):
             combined_dict_items = list(combined_dict.items())
             for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                 self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
+            mlflow_tags = os.getenv("MLFLOW_TAGS", None)
+            if mlflow_tags:
+                mlflow_tags = json.loads(mlflow_tags)
+                self._ml_flow.set_tags(mlflow_tags)
         self._initialized = True

     def on_train_begin(self, args, state, control, model=None, **kwargs):
@@ -849,11 +869,13 @@ class MLflowCallback(TrainerCallback):
             if self._log_artifacts:
                 logger.info("Logging artifacts. This may take time.")
                 self._ml_flow.log_artifacts(args.output_dir)
+        if self._auto_end_run and self._ml_flow.active_run():
+            self._ml_flow.end_run()

     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
-        if self._ml_flow.active_run() is not None:
+        if self._auto_end_run and self._ml_flow.active_run() is not None:
             self._ml_flow.end_run()
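With the `_auto_end_run` flag, the callback now ends only the runs it opened itself, which is what makes `MLFLOW_NESTED_RUN` useful: a run started by the user stays active after training while the callback's nested child run is closed. Below is a rough, Trainer-free sketch of the MLflow mechanics involved; the run names and the logged metric are illustrative only.

```python
import mlflow

# Sketch of the run lifecycle the patched callback follows when MLFLOW_NESTED_RUN
# is truthy: setup() starts a nested child run, on_train_end()/__del__ end only
# that child, and the user's parent run remains active.
with mlflow.start_run(run_name="parent-run") as parent:      # run owned by the user
    mlflow.start_run(run_name="trainer-run", nested=True)    # what setup() does
    mlflow.log_metric("loss", 0.42)                           # training metrics land in the child run
    mlflow.end_run()                                          # what on_train_end() does

    # the parent run is active again and untouched by the callback
    assert mlflow.active_run().info.run_id == parent.info.run_id
```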