Unverified Commit 03986609 authored by Dave Berenbaum's avatar Dave Berenbaum Committed by GitHub
Browse files

integrations: fix DVCLiveCallback model logging (#28653)

parent 1fc12960
...@@ -1635,16 +1635,21 @@ class DVCLiveCallback(TrainerCallback): ...@@ -1635,16 +1635,21 @@ class DVCLiveCallback(TrainerCallback):
raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.") raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
from dvclive import Live from dvclive import Live
self._log_model = log_model
self._initialized = False self._initialized = False
self.live = None self.live = None
if isinstance(live, Live): if isinstance(live, Live):
self.live = live self.live = live
self._initialized = True
elif live is not None: elif live is not None:
raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live") raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")
self._log_model = log_model
if self._log_model is None:
log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
self._log_model = True
elif log_model_env.lower() == "all":
self._log_model = "all"
def setup(self, args, state, model): def setup(self, args, state, model):
""" """
Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
...@@ -1659,12 +1664,6 @@ class DVCLiveCallback(TrainerCallback): ...@@ -1659,12 +1664,6 @@ class DVCLiveCallback(TrainerCallback):
from dvclive import Live from dvclive import Live
self._initialized = True self._initialized = True
if self._log_model is not None:
log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL")
if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
self._log_model = True
elif log_model_env.lower() == "all":
self._log_model = "all"
if state.is_world_process_zero: if state.is_world_process_zero:
if not self.live: if not self.live:
self.live = Live() self.live = Live()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment