Unverified Commit 1c7ebf1d authored by João Nadkarni's avatar João Nadkarni Committed by GitHub
Browse files

don't log base model architecture in wandb if log model is false (#32143)



* don't log base model architecture in wandb is log model is false

* Update src/transformers/integrations/integration_utils.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* convert log model setting into an enum

* fix formatting

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent c46edfb8
......@@ -26,6 +26,7 @@ import shutil
import sys
import tempfile
from dataclasses import asdict, fields
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
......@@ -726,6 +727,35 @@ def save_model_architecture_to_file(model: Any, output_dir: str):
print(model, file=f)
class WandbLogModel(str, Enum):
"""Enum of possible log model values in W&B."""
CHECKPOINT = "checkpoint"
END = "end"
FALSE = "false"
@property
def is_enabled(self) -> bool:
"""Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
@classmethod
def _missing_(cls, value: Any) -> "WandbLogModel":
if not isinstance(value, str):
raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
if value.upper() in ENV_VARS_TRUE_VALUES:
DeprecationWarning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
return WandbLogModel.END
logger.warning(
f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
)
return WandbLogModel.FALSE
class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
......@@ -740,16 +770,7 @@ class WandbCallback(TrainerCallback):
self._wandb = wandb
self._initialized = False
# log model
if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
DeprecationWarning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
self._log_model = "end"
else:
self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()
self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
def setup(self, args, state, model, **kwargs):
"""
......@@ -834,6 +855,7 @@ class WandbCallback(TrainerCallback):
logger.info("Could not log the number of model parameters in Weights & Biases.")
# log the initial model architecture to an artifact
if self._log_model.is_enabled:
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
......@@ -880,7 +902,7 @@ class WandbCallback(TrainerCallback):
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
from ..trainer import Trainer
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
......@@ -938,7 +960,7 @@ class WandbCallback(TrainerCallback):
self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
def on_save(self, args, state, control, **kwargs):
if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
checkpoint_metadata = {
k: v
for k, v in dict(self._wandb.summary).items()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment