Unverified Commit 03986609 authored by Dave Berenbaum's avatar Dave Berenbaum Committed by GitHub
Browse files

integrations: fix DVCLiveCallback model logging (#28653)

parent 1fc12960
......@@ -1635,16 +1635,21 @@ class DVCLiveCallback(TrainerCallback):
raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
from dvclive import Live
self._log_model = log_model
self._initialized = False
self.live = None
if isinstance(live, Live):
self.live = live
self._initialized = True
elif live is not None:
raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")
self._log_model = log_model
if self._log_model is None:
log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
self._log_model = True
elif log_model_env.lower() == "all":
self._log_model = "all"
def setup(self, args, state, model):
"""
Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
......@@ -1659,12 +1664,6 @@ class DVCLiveCallback(TrainerCallback):
from dvclive import Live
self._initialized = True
if self._log_model is not None:
log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL")
if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
self._log_model = True
elif log_model_env.lower() == "all":
self._log_model = "all"
if state.is_world_process_zero:
if not self.live:
self.live = Live()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment