Unverified Commit 4ab7a282, authored by Bharat Ramanathan, committed by GitHub

feat: Upgrade Weights & Biases callback (#30135)

* feat: upgrade wandb callback with new features

* fix: ci issues with imports and run fixup
parent 30b45320
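For context, a minimal sketch of how the upgraded callback is driven from user code. The model name, dataset, and hyperparameters below are placeholders; WANDB_PROJECT, WANDB_LOG_MODEL, WANDB_WATCH and report_to="wandb" are the existing switches the callback reads.

import os

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

os.environ["WANDB_PROJECT"] = "my-project"    # placeholder project name
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # upload every saved checkpoint as a model artifact
os.environ["WANDB_WATCH"] = "false"           # skip gradient/parameter histograms

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

train_ds = load_dataset("imdb", split="train[:1%]").map(
    lambda batch: tokenizer(batch["text"], truncation=True), batched=True
)

args = TrainingArguments(
    output_dir="out",
    report_to="wandb",  # activates WandbCallback
    logging_steps=10,
    save_steps=50,
    max_steps=100,
)

Trainer(model=model, args=args, train_dataset=train_ds, tokenizer=tokenizer).train()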
@@ -31,8 +31,17 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
import numpy as np
import packaging.version
from .. import PreTrainedModel, TFPreTrainedModel
from .. import __version__ as version
from ..utils import (
    PushToHubMixin,
    flatten_dict,
    is_datasets_available,
    is_pandas_available,
    is_tf_available,
    is_torch_available,
    logging,
)
logger = logging.get_logger(__name__)
@@ -69,6 +78,7 @@ if TYPE_CHECKING and _has_neptune:
except importlib.metadata.PackageNotFoundError:
    _has_neptune = False
from .. import modelcard # noqa: E402
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402
@@ -663,6 +673,22 @@ class TensorBoardCallback(TrainerCallback):
        self.tb_writer = None
def save_model_architecture_to_file(model: Any, output_dir: str):
    with open(f"{output_dir}/model_architecture.txt", "w+") as f:
        if isinstance(model, PreTrainedModel):
            print(model, file=f)
        elif is_tf_available() and isinstance(model, TFPreTrainedModel):

            def print_to_file(s):
                print(s, file=f)

            model.summary(print_fn=print_to_file)
        elif is_torch_available() and (
            isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
        ):
            print(model, file=f)
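A quick sketch of what this new helper writes, assuming it is importable from transformers.integrations.integration_utils (the module this diff edits); the checkpoint name is a placeholder.

import tempfile

from transformers import AutoModel
from transformers.integrations.integration_utils import save_model_architecture_to_file

model = AutoModel.from_pretrained("distilbert-base-uncased")  # placeholder checkpoint

with tempfile.TemporaryDirectory() as tmp:
    save_model_architecture_to_file(model, tmp)
    # For a PyTorch-backed PreTrainedModel this is simply repr(model), i.e. the module tree.
    print(open(f"{tmp}/model_architecture.txt").read()[:300])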
class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
@@ -728,6 +754,9 @@ class WandbCallback(TrainerCallback):
if hasattr(model, "config") and model.config is not None:
    model_config = model.config.to_dict()
    combined_dict = {**model_config, **combined_dict}
if hasattr(model, "peft_config") and model.peft_config is not None:
    peft_config = model.peft_config
    combined_dict = {**{"peft_config": peft_config}, **combined_dict}
trial_name = state.trial_name
init_args = {}
if trial_name is not None:
@@ -756,6 +785,51 @@ class WandbCallback(TrainerCallback):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer")
# add number of model parameters to wandb config
if any(
    (
        isinstance(model, PreTrainedModel),
        isinstance(model, PushToHubMixin),
        (is_tf_available() and isinstance(model, TFPreTrainedModel)),
        (is_torch_available() and isinstance(model, torch.nn.Module)),
    )
):
    self._wandb.config["model/num_parameters"] = model.num_parameters()

# log the initial model and architecture to an artifact
with tempfile.TemporaryDirectory() as temp_dir:
    model_name = (
        f"model-{self._wandb.run.id}"
        if (args.run_name is None or args.run_name == args.output_dir)
        else f"model-{self._wandb.run.name}"
    )
    model_artifact = self._wandb.Artifact(
        name=model_name,
        type="model",
        metadata={
            "model_config": model.config.to_dict() if hasattr(model, "config") else None,
            "num_parameters": self._wandb.config.get("model/num_parameters"),
            "initial_model": True,
        },
    )
    model.save_pretrained(temp_dir)
    # add the architecture to a separate text file
    save_model_architecture_to_file(model, temp_dir)
    for f in Path(temp_dir).glob("*"):
        if f.is_file():
            with model_artifact.new_file(f.name, mode="wb") as fa:
                fa.write(f.read_bytes())
    self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])

    badge_markdown = (
        f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
        f'-28.svg" alt="Visualize in Weights & Biases" width="20'
        f'0" height="32"/>]({self._wandb.run.get_url()})'
    )
    modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
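The block above uploads the untrained model once per run under the "base_model" alias; a sketch of pulling it back with the public wandb API (entity, project and run id are placeholders):

import wandb

api = wandb.Api()
artifact = api.artifact("my-entity/my-project/model-abc123:base_model", type="model")
local_dir = artifact.download()  # config, weights, and model_architecture.txt
print(artifact.metadata)         # e.g. {"initial_model": True, "num_parameters": ...}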
def on_train_begin(self, args, state, control, model=None, **kwargs):
    if self._wandb is None:
        return
@@ -786,20 +860,25 @@ class WandbCallback(TrainerCallback):
    else {
        f"eval/{args.metric_for_best_model}": state.best_metric,
        "train/total_floss": state.total_flos,
        "model/num_parameters": self._wandb.config.get("model/num_parameters"),
    }
)
metadata["final_model"] = True
logger.info("Logging model artifacts. ...")
model_name = (
    f"model-{self._wandb.run.id}"
    if (args.run_name is None or args.run_name == args.output_dir)
    else f"model-{self._wandb.run.name}"
)
# add the model architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
    if f.is_file():
        with artifact.new_file(f.name, mode="wb") as fa:
            fa.write(f.read_bytes())
self._wandb.run.log_artifact(artifact, aliases=["final_model"])
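on_train_end now tags the end-of-training artifact with "final_model" and records the best metric in its metadata, so when model logging is enabled (WANDB_LOG_MODEL) the finished model can be reloaded straight from W&B. A sketch with placeholder names:

import wandb
from transformers import AutoModelForSequenceClassification

run = wandb.init(project="my-project")  # placeholder project
final = run.use_artifact("model-abc123:final_model", type="model")
model = AutoModelForSequenceClassification.from_pretrained(final.download())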
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
    single_value_scalars = [
@@ -829,18 +908,30 @@ class WandbCallback(TrainerCallback):
    for k, v in dict(self._wandb.summary).items()
    if isinstance(v, numbers.Number) and not k.startswith("_")
}
checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
checkpoint_name = (
    f"model-{self._wandb.run.id}"
    if (args.run_name is None or args.run_name == args.output_dir)
    else f"model-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
artifact.add_dir(artifact_path)
self._wandb.log_artifact(
    artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
)
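With WANDB_LOG_MODEL="checkpoint", each saved checkpoint directory is now retrievable by the richer aliases above. A sketch of resuming from one (entity, project, run id and step number are placeholders):

import wandb

api = wandb.Api()
ckpt = api.artifact("my-entity/my-project/model-abc123:checkpoint_global_step_50", type="model")
ckpt_dir = ckpt.download()

# With a Trainer built as usual, training can pick up from the downloaded checkpoint:
# trainer.train(resume_from_checkpoint=ckpt_dir)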
def on_predict(self, args, state, control, metrics, **kwargs):
    if self._wandb is None:
        return
    if not self._initialized:
        self.setup(args, state, **kwargs)
    if state.is_world_process_zero:
        metrics = rewrite_logs(metrics)
        self._wandb.log(metrics)
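The new on_predict hook means metrics returned by Trainer.predict also reach the run; rewrite_logs prefixes them, so e.g. test_loss is logged as test/loss. A sketch, reusing the placeholder trainer and a tokenized test split from the earlier setup:

# Assumes `trainer` and a tokenized `test_ds` exist as in the earlier placeholder example.
predictions = trainer.predict(test_ds, metric_key_prefix="test")
# WandbCallback.on_predict logs predictions.metrics (e.g. "test/loss") to the active run.
print(predictions.metrics)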
class CometCallback(TrainerCallback):
    ...