Unverified Commit ddaafd78 authored by noise-field, committed by GitHub

Fix mlflow param overflow clean (#10071)

* Unify logging with f-strings

* Get limits from MLflow rather than hardcoding them (see the sketch after this message)

* Add a check for parameter length overflow

Also constants are marked as internal

* Don't stop run in on_train_end

This causes bad behaviour when there is a separate validation step:
validation gets recorded as a separate run.

* Fix style
parent ece6c514
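
Before the diff, a minimal standalone sketch of the approach this commit takes; this is not code from the PR. It reads MLflow's documented limits from mlflow.utils.validation instead of hardcoding them, drops any parameter whose stringified value exceeds the per-value limit, and batches the rest. The `params` dict is a hypothetical stand-in for the Trainer's combined config.

    import mlflow

    # Limits defined by MLflow itself (250 chars per value, 100 entries
    # per batch at the time of this PR) -- read them, don't hardcode:
    max_len = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
    per_batch = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

    params = {"lr": 5e-5, "notes": "x" * 300}  # hypothetical config dict

    # MLflow stringifies every value, so filter on the stringified length:
    params = {k: v for k, v in params.items() if len(str(v)) <= max_len}

    # log_params() rejects batches larger than per_batch, so split:
    items = list(params.items())
    with mlflow.start_run():
        for i in range(0, len(items), per_batch):
            mlflow.log_params(dict(items[i : i + per_batch]))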
@@ -707,12 +707,13 @@ class MLflowCallback(TrainerCallback):
     A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
     """
 
-    MAX_LOG_SIZE = 100
-
     def __init__(self):
         assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
         import mlflow
 
+        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
+        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
+
         self._initialized = False
         self._log_artifacts = False
         self._ml_flow = mlflow
@@ -738,10 +739,21 @@ class MLflowCallback(TrainerCallback):
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
                 combined_dict = {**model_config, **combined_dict}
+            # remove params that are too long for MLflow
+            for name, value in list(combined_dict.items()):
+                # internally, all values are converted to str in MLflow
+                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
+                    logger.warning(
+                        f"Trainer is attempting to log a value of "
+                        f'"{value}" for key "{name}" as a parameter. '
+                        f"MLflow's log_param() only accepts values no longer than "
+                        f"250 characters so we dropped this attribute."
+                    )
+                    del combined_dict[name]
             # MLflow cannot log more than 100 values in one go, so we have to split it
             combined_dict_items = list(combined_dict.items())
-            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
-                self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
+                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
         self._initialized = True
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
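
A detail worth noting in the hunk above: the new loop iterates over a snapshot, list(combined_dict.items()), rather than the dict itself, because deleting keys from a dict while iterating over it raises a RuntimeError in Python. A quick illustration:

    d = {"a": 1, "bb": 2}
    for k, v in list(d.items()):  # snapshot; safe to delete from d below
        if len(k) > 1:
            del d[k]
    assert d == {"a": 1}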
@@ -757,13 +769,10 @@ class MLflowCallback(TrainerCallback):
                     self._ml_flow.log_metric(k, v, step=state.global_step)
                 else:
                     logger.warning(
-                        "Trainer is attempting to log a value of "
-                        '"%s" of type %s for key "%s" as a metric. '
-                        "MLflow's log_metric() only accepts float and "
-                        "int types so we dropped this attribute.",
-                        v,
-                        type(v),
-                        k,
+                        f"Trainer is attempting to log a value of "
+                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
+                        f"MLflow's log_metric() only accepts float and "
+                        f"int types so we dropped this attribute."
                     )
 
     def on_train_end(self, args, state, control, **kwargs):
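
One side effect of the f-string unification above: %-style logger calls defer formatting until the record is actually emitted, while f-strings format eagerly even when the logger is filtered out. For a warning that fires rarely the difference is negligible, which presumably motivated trading laziness for readability. Both forms, sketched:

    import logging

    logger = logging.getLogger(__name__)
    v, k = "abc", "loss"
    logger.warning("dropped %s of type %s for key %s", v, type(v), k)  # lazy
    logger.warning(f'dropped "{v}" of type {type(v)} for key "{k}"')   # eager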
@@ -771,13 +780,12 @@ class MLflowCallback(TrainerCallback):
         if self._log_artifacts:
             logger.info("Logging artifacts. This may take time.")
             self._ml_flow.log_artifacts(args.output_dir)
-        self._ml_flow.end_run()
 
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
         if self._ml_flow.active_run is not None:
-            self._ml_flow.end_run(status="KILLED")
+            self._ml_flow.end_run()
 
 
 INTEGRATION_TO_CALLBACK = {
...
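
The last hunk is the behavioural fix from the commit message: on_train_end no longer ends the active run, so a later evaluation step logs into the same run instead of implicitly opening a new one, and final cleanup moves to __del__. A sketch of the fluent-API behaviour being worked around, assuming a default local tracking setup:

    import mlflow

    mlflow.start_run()
    mlflow.log_metric("train_loss", 0.10, step=100)
    # If the run were ended here, as on_train_end used to do, the next
    # log_metric() call would implicitly start a brand-new run, recording
    # the validation metrics separately from training:
    mlflow.log_metric("eval_loss", 0.12, step=100)  # same run: still active
    mlflow.end_run()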