Unverified Commit 06901162 authored by eajechiloae, committed by GitHub

ClearMLCallback enhancements: support multiple runs and handle logging better (#28559)



* add clearml tracker

* support multiple train runs

* remove bad code

* add UI entries for config/hparams overrides

* handle models in different tasks

* run ruff format

* tidy code based on code review

---------
Co-authored-by: Eugen Ajechiloae <eugenajechiloae@gmail.com>
parent ba3264b4
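For context, a minimal sketch (not part of this commit) of the multi-run flow this enables. Initializing the task externally via `Task.init` keeps it open across `trainer.train()` calls, so a second run lands in the same ClearML task under `_2`-suffixed titles and sections instead of overwriting the first. Project, task, model, and dataset below are placeholders; any `Trainer` setup works.

```python
import os

from clearml import Task
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Checkpoint logging is now opt-in for externally created tasks
# (default TRUE only when the callback creates the task itself).
os.environ["CLEARML_LOG_MODEL"] = "TRUE"

# Creating the task up front means the callback connects to it instead of
# creating (and closing) its own, letting several runs share one task.
task = Task.init(project_name="examples", task_name="multi-run")

checkpoint = "distilbert-base-uncased"  # stand-in model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Tiny stand-in dataset, just enough for Trainer to run.
data = Dataset.from_dict({"text": ["good", "bad"] * 8, "label": [1, 0] * 8})
data = data.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=32))

for run in range(2):
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir=f"out-{run}", report_to=["clearml"], num_train_epochs=1),
        train_dataset=data,
    )
    # Run 1 logs plain titles ("train", ...) and a "Transformers" hparams section;
    # run 2 gets "_2"-suffixed titles and a "Transformers_2" section.
    trainer.train()

task.close()
```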
@@ -24,7 +24,7 @@
 import pickle
 import shutil
 import sys
 import tempfile
-from dataclasses import asdict
+from dataclasses import asdict, fields
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
@@ -1438,6 +1438,24 @@ class ClearMLCallback(TrainerCallback):
         Whether to log models as artifacts during training.
     """

+    log_suffix = ""
+
+    _hparams_section = "Transformers"
+    _model_config_section = "Model Configuration"
+    _ignore_hparams_overrides = "_ignore_hparams_ui_overrides_"
+    _ignore_model_config_overrides = "_ignore_model_config_ui_overrides_"
+    _model_config_description = "The configuration of model number {}."
+    _model_config_description_note = (
+        "Note that, when cloning this task and running it remotely,"
+        " the configuration might be applied to another model instead of this one."
+        " To avoid this, initialize the task externally by calling `Task.init`"
+        " before the `ClearMLCallback` is instantiated."
+    )
+    _train_run_counter = 0
+    _model_connect_counter = 0
+    _task_created_in_callback = False
+    _should_close_on_train_end = None
+
     def __init__(self):
         if is_clearml_available():
             import clearml
@@ -1447,25 +1465,38 @@ class ClearMLCallback(TrainerCallback):
             raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")

         self._initialized = False
-        self._initialized_externally = False
         self._clearml_task = None

-        self._log_model = os.getenv("CLEARML_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
+        self._log_model = False
+        self._checkpoints_saved = []

     def setup(self, args, state, model, tokenizer, **kwargs):
         if self._clearml is None:
             return
         if self._initialized:
             return
+        ClearMLCallback._train_run_counter += 1
+        ClearMLCallback._model_connect_counter += 1
+        ClearMLCallback.log_suffix = (
+            "" if ClearMLCallback._train_run_counter == 1 else "_" + str(ClearMLCallback._train_run_counter)
+        )
         if state.is_world_process_zero:
             logger.info("Automatic ClearML logging enabled.")
             if self._clearml_task is None:
+                if ClearMLCallback._should_close_on_train_end is None:
+                    if not self._clearml.Task.running_locally() or self._clearml.Task.current_task():
+                        ClearMLCallback._should_close_on_train_end = False
+                    else:
+                        ClearMLCallback._should_close_on_train_end = True
                 # This might happen when running inside of a pipeline, where the task is already initialized
                 # from outside of Hugging Face
-                if self._clearml.Task.current_task():
+                if self._clearml.Task.running_locally() and self._clearml.Task.current_task():
                     self._clearml_task = self._clearml.Task.current_task()
-                    self._initialized = True
-                    self._initialized_externally = True
+                    self._log_model = os.getenv(
+                        "CLEARML_LOG_MODEL",
+                        "FALSE" if not ClearMLCallback._task_created_in_callback else "TRUE",
+                    ).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
                     logger.info("External ClearML Task has been connected.")
                 else:
                     self._clearml_task = self._clearml.Task.init(
@@ -1474,27 +1505,83 @@ class ClearMLCallback(TrainerCallback):
                         auto_connect_frameworks={"tensorboard": False, "pytorch": False},
                         output_uri=True,
                     )
-                    self._initialized = True
+                    self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union(
+                        {"TRUE"}
+                    )
+                    ClearMLCallback._task_created_in_callback = True
                     logger.info("ClearML Task has been initialized.")
+            self._initialized = True
+
+            suffixed_hparams_section = ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
+            ignore_hparams_config_section = suffixed_hparams_section + "/" + ClearMLCallback._ignore_hparams_overrides
+            if self._clearml.Task.running_locally():
+                self._copy_training_args_as_hparams(args, suffixed_hparams_section)
+                self._clearml_task.set_parameter(
+                    name=ignore_hparams_config_section,
+                    value=True,
+                    value_type=bool,
+                    description=(
+                        "If True, ignore Transformers hyperparameters overrides done in the UI/backend "
+                        + "when running remotely. Otherwise, the overrides will be applied when running remotely"
+                    ),
+                )
+            elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True):
+                self._clearml_task.connect(args, suffixed_hparams_section)
+            else:
+                self._copy_training_args_as_hparams(
+                    args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
+                )

-            self._clearml_task.connect(args, "Args")
-            if hasattr(model, "config") and model.config is not None:
-                self._clearml_task.connect(model.config, "Model Configuration")
+            if getattr(model, "config", None) is not None:
+                ignore_model_config_section = (
+                    suffixed_hparams_section + "/" + ClearMLCallback._ignore_model_config_overrides
+                )
+                configuration_object_description = ClearMLCallback._model_config_description.format(
+                    ClearMLCallback._model_connect_counter
+                )
+                if ClearMLCallback._model_connect_counter != ClearMLCallback._train_run_counter:
+                    configuration_object_description += " " + ClearMLCallback._model_config_description_note
+                if self._clearml.Task.running_locally():
+                    self._clearml_task.set_parameter(
+                        name=ignore_model_config_section,
+                        value=True,
+                        value_type=bool,
+                        description=(
+                            "If True, ignore Transformers model configuration overrides done in the UI/backend "
+                            + "when running remotely. Otherwise, the overrides will be applied when running remotely"
+                        ),
+                    )
+                    self._clearml_task.set_configuration_object(
+                        name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
+                        config_dict=model.config.to_dict(),
+                        description=configuration_object_description,
+                    )
+                elif not self._clearml_task.get_parameter(ignore_model_config_section, default=True, cast=True):
+                    model.config = model.config.from_dict(
+                        self._clearml_task.get_configuration_object_as_dict(
+                            ClearMLCallback._model_config_section + ClearMLCallback.log_suffix
+                        )
+                    )
+                else:
+                    self._clearml_task.set_configuration_object(
+                        name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
+                        config_dict=model.config.to_dict(),
+                        description=configuration_object_description,
+                    )
     def on_train_begin(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._clearml is None:
             return
+        self._checkpoints_saved = []
         if state.is_hyper_param_search:
             self._initialized = False
         if not self._initialized:
             self.setup(args, state, model, tokenizer, **kwargs)

-    def on_train_end(self, args, state, control, model=None, tokenizer=None, metrics=None, logs=None, **kwargs):
-        if self._clearml is None:
-            return
-        if self._clearml_task and state.is_world_process_zero and not self._initialized_externally:
-            # Close ClearML Task at the end of training
+    def on_train_end(self, args, state, control, **kwargs):
+        if ClearMLCallback._should_close_on_train_end:
             self._clearml_task.close()
+            ClearMLCallback._train_run_counter = 0

     def on_log(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
         if self._clearml is None:
@@ -1517,18 +1604,29 @@ class ClearMLCallback(TrainerCallback):
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
                     if k in single_value_scalars:
-                        self._clearml_task.get_logger().report_single_value(name=k, value=v)
+                        self._clearml_task.get_logger().report_single_value(
+                            name=k + ClearMLCallback.log_suffix, value=v
+                        )
                     elif k.startswith(eval_prefix):
                         self._clearml_task.get_logger().report_scalar(
-                            title=k[eval_prefix_len:], series="eval", value=v, iteration=state.global_step
+                            title="eval" + ClearMLCallback.log_suffix,
+                            series=k[eval_prefix_len:],
+                            value=v,
+                            iteration=state.global_step,
                         )
                     elif k.startswith(test_prefix):
                         self._clearml_task.get_logger().report_scalar(
-                            title=k[test_prefix_len:], series="test", value=v, iteration=state.global_step
+                            title="test" + ClearMLCallback.log_suffix,
+                            series=k[test_prefix_len:],
+                            value=v,
+                            iteration=state.global_step,
                         )
                     else:
                         self._clearml_task.get_logger().report_scalar(
-                            title=k, series="train", value=v, iteration=state.global_step
+                            title="train" + ClearMLCallback.log_suffix,
+                            series=k,
+                            value=v,
+                            iteration=state.global_step,
                         )
                 else:
                     logger.warning(
@@ -1542,8 +1640,42 @@ class ClearMLCallback(TrainerCallback):
         if self._log_model and self._clearml_task and state.is_world_process_zero:
             ckpt_dir = f"checkpoint-{state.global_step}"
             artifact_path = os.path.join(args.output_dir, ckpt_dir)
-            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
-            self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)
+            name = ckpt_dir + ClearMLCallback.log_suffix
+            logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
+            output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
+            output_model.connect(task=self._clearml_task, name=name)
+            output_model.update_weights_package(
+                weights_path=artifact_path,
+                target_filename=ckpt_dir,
+                iteration=state.global_step,
+                auto_delete_file=False,
+            )
+            self._checkpoints_saved.append(output_model)
+            while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
+                try:
+                    self._clearml.model.Model.remove(
+                        self._checkpoints_saved[0],
+                        delete_weights_file=True,
+                        force=True,
+                        raise_on_errors=True,
+                    )
+                except Exception as e:
+                    logger.warning(
+                        "Could not remove checkpoint `{}` after going over the `save_total_limit`. Error is: {}".format(
+                            self._checkpoints_saved[0].name, e
+                        )
+                    )
+                    break
+                self._checkpoints_saved = self._checkpoints_saved[1:]
+
+    def _copy_training_args_as_hparams(self, training_args, prefix):
+        as_dict = {
+            field.name: getattr(training_args, field.name)
+            for field in fields(training_args)
+            if field.init and not field.name.endswith("_token")
+        }
+        flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
+        self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)
+

 class FlyteCallback(TrainerCallback):
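The `_ignore_*_ui_overrides_` entries registered in `setup` are ordinary task parameters, so they can be flipped from the UI or the SDK before cloning and enqueuing a run. A short sketch of the SDK route, with hypothetical project/task names; the parameter keys mirror the section constants in the diff above:

```python
from clearml import Task

# Hypothetical identifiers; substitute your own project/task names.
task = Task.get_task(project_name="examples", task_name="multi-run")

# The switches default to True, i.e. UI/backend edits are ignored on remote reruns.
print(task.get_parameters().get("Transformers/_ignore_hparams_ui_overrides_"))

# Set them to False before cloning/enqueuing so the remote run applies UI edits
# to the training args and the model configuration, respectively.
task.set_parameter("Transformers/_ignore_hparams_ui_overrides_", False)
task.set_parameter("Transformers/_ignore_model_config_ui_overrides_", False)
```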