Unverified Commit 12d1f077 authored by Kshitiz Sharma's avatar Kshitiz Sharma Committed by GitHub
Browse files

integrations: mlflow: skip start_run() if a run is already active and sanity...


integrations: mlflow: skip start_run() if a run is already active and sanity check on enabling integration (#16131)

* integrations: mlflow: skip start_run() call if a run is already active

* integrations: typo fix
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 47cccb53
...@@ -96,6 +96,8 @@ def is_azureml_available(): ...@@ -96,6 +96,8 @@ def is_azureml_available():
def is_mlflow_available(): def is_mlflow_available():
if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
return False
return importlib.util.find_spec("mlflow") is not None return importlib.util.find_spec("mlflow") is not None
...@@ -758,7 +760,8 @@ class AzureMLCallback(TrainerCallback): ...@@ -758,7 +760,8 @@ class AzureMLCallback(TrainerCallback):
class MLflowCallback(TrainerCallback): class MLflowCallback(TrainerCallback):
""" """
A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
""" """
def __init__(self): def __init__(self):
...@@ -789,7 +792,8 @@ class MLflowCallback(TrainerCallback): ...@@ -789,7 +792,8 @@ class MLflowCallback(TrainerCallback):
if log_artifacts in {"TRUE", "1"}: if log_artifacts in {"TRUE", "1"}:
self._log_artifacts = True self._log_artifacts = True
if state.is_world_process_zero: if state.is_world_process_zero:
self._ml_flow.start_run(run_name=args.run_name) if self._ml_flow.active_run is None:
self._ml_flow.start_run(run_name=args.run_name)
combined_dict = args.to_dict() combined_dict = args.to_dict()
if hasattr(model, "config") and model.config is not None: if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict() model_config = model.config.to_dict()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment