Unverified commit 1c7ebf1d authored by João Nadkarni, committed by GitHub

don't log base model architecture in wandb if log model is false (#32143)



* don't log base model architecture in wandb if log model is false

* Update src/transformers/integrations/integration_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* convert log model setting into an enum

* fix formatting

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent c46edfb8
@@ -26,6 +26,7 @@ import shutil
 import sys
 import tempfile
 from dataclasses import asdict, fields
+from enum import Enum
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
@@ -726,6 +727,35 @@ def save_model_architecture_to_file(model: Any, output_dir: str):
         print(model, file=f)
 
 
+class WandbLogModel(str, Enum):
+    """Enum of possible log model values in W&B."""
+
+    CHECKPOINT = "checkpoint"
+    END = "end"
+    FALSE = "false"
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
+        return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
+
+    @classmethod
+    def _missing_(cls, value: Any) -> "WandbLogModel":
+        if not isinstance(value, str):
+            raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
+        if value.upper() in ENV_VARS_TRUE_VALUES:
+            DeprecationWarning(
+                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
+                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
+            )
+            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
+            return WandbLogModel.END
+        logger.warning(
+            f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
+        )
+        return WandbLogModel.FALSE
+
+
 class WandbCallback(TrainerCallback):
     """
     A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
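
For reference (not part of the commit), here is a minimal sketch of how the new enum resolves the different `WANDB_LOG_MODEL` values. It assumes `WandbLogModel` is importable from `transformers.integrations.integration_utils` as introduced by this patch, and that the module-level `ENV_VARS_TRUE_VALUES` and `logger` behave as elsewhere in the file:

```python
# Minimal sketch, not part of the diff: exercising WandbLogModel from this patch.
from transformers.integrations.integration_utils import WandbLogModel

assert WandbLogModel("checkpoint") is WandbLogModel.CHECKPOINT
assert WandbLogModel("end").is_enabled          # "end" and "checkpoint" enable artifact logging
assert not WandbLogModel("false").is_enabled    # the default keeps artifact logging off

# Legacy truthy strings ("TRUE", "1", "yes", ...) miss the declared values and go
# through _missing_(), which coerces them to END; note the DeprecationWarning
# object there is constructed but never raised or passed to warnings.warn(),
# so only the logger.info call surfaces the coercion.
assert WandbLogModel("TRUE") is WandbLogModel.END

# Anything unrecognized is downgraded to FALSE with a logger.warning.
assert WandbLogModel("maybe") is WandbLogModel.FALSE
```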
@@ -740,16 +770,7 @@ class WandbCallback(TrainerCallback):
         self._wandb = wandb
         self._initialized = False
-        # log model
-        if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
-            DeprecationWarning(
-                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
-                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
-            )
-            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
-            self._log_model = "end"
-        else:
-            self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()
+        self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
 
     def setup(self, args, state, model, **kwargs):
         """
@@ -834,6 +855,7 @@ class WandbCallback(TrainerCallback):
             logger.info("Could not log the number of model parameters in Weights & Biases.")
 
         # log the initial model architecture to an artifact
+        if self._log_model.is_enabled:
             with tempfile.TemporaryDirectory() as temp_dir:
                 model_name = (
                     f"model-{self._wandb.run.id}"
@@ -880,7 +902,7 @@ class WandbCallback(TrainerCallback):
     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._wandb is None:
             return
-        if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
+        if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
             from ..trainer import Trainer
 
             fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -938,7 +960,7 @@ class WandbCallback(TrainerCallback):
             self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
 
     def on_save(self, args, state, control, **kwargs):
-        if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
+        if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
             checkpoint_metadata = {
                 k: v
                 for k, v in dict(self._wandb.summary).items()
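
An end-to-end smoke check (illustrative only, requires `wandb` to be installed) showing that the default `"false"` setting now disables all three artifact paths gated in this diff:

```python
import os

os.environ["WANDB_LOG_MODEL"] = "false"  # or simply leave it unset

from transformers.integrations.integration_utils import WandbCallback, WandbLogModel

cb = WandbCallback()  # raises RuntimeError if wandb is not installed
assert cb._log_model is WandbLogModel.FALSE
# With is_enabled False, setup() skips the "initial model architecture" artifact
# (the new gate in this diff), on_train_end() skips the final model upload, and
# on_save() skips per-checkpoint artifacts.
assert not cb._log_model.is_enabled
```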