Unverified commit 30fa0b78 authored by Boris Dayma, committed by GitHub

feat(wandb): save model as artifact (#8119)

* feat(wandb): log artifacts

* fix: typo

* feat(wandb): ensure name is allowed

* feat(wandb): log artifact

* feat(wandb): saving logic

* style: improve formatting

* fix: unrelated typo

* feat: use a fake trainer

* fix: simplify

* feat(wandb): log model files as artifact

* style: fix style

* docs(wandb): correct description

* feat: unpack model + allow env truthy values

* feat: TrainerCallback can access tokenizer

* style: fix style

* feat(wandb): log more interesting metadata

* feat: unpack tokenizer

* feat(wandb): metadata with load_best_model_at_end

* feat(wandb): more robust metadata

* style(wandb): fix formatting
parent 143289dc
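
A minimal usage sketch of the feature this commit adds: it assumes `wandb` is installed and authenticated (so that `WandbCallback` is registered automatically), and the model, tokenizer, and dataset names are placeholders defined elsewhere.

# Sketch: opt in to model upload before training starts (names below are placeholders).
import os

os.environ["WANDB_LOG_MODEL"] = "TRUE"  # any truthy value accepted by ENV_VARS_TRUE_VALUES (e.g. "1", "TRUE")

from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="out", run_name="my-wandb-run")
# model, tokenizer and train_dataset are assumed to be defined elsewhere
trainer = Trainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()  # on_train_end now saves the model to a temporary directory and logs it as a wandb Artifact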
@@ -15,8 +15,13 @@
 Integrations with other Python libraries.
 """
 import math
+import numbers
 import os
+import re
+import tempfile
+from pathlib import Path
 
+from .file_utils import ENV_VARS_TRUE_VALUES
 from .trainer_utils import EvaluationStrategy
 from .utils import logging
@@ -369,6 +374,8 @@ class WandbCallback(TrainerCallback):
         <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
 
         Environment:
+            WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to log model as artifact at the end of training.
             WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
                 Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                 logging or :obj:`"all"` to log gradients and parameters.
@@ -407,12 +414,44 @@ class WandbCallback(TrainerCallback):
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                 wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
+        # log outputs
+        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
         hp_search = state.is_hyper_param_search
         if not self._initialized or hp_search:
             self.setup(args, state, model, reinit=hp_search, **kwargs)
 
+    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
+        # commit last step
+        wandb.log({})
+        if self._log_model and self._initialized and state.is_world_process_zero:
+            from .trainer import Trainer
+
+            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
+            with tempfile.TemporaryDirectory() as temp_dir:
+                fake_trainer.save_model(temp_dir)
+                # use run name and ensure it's a valid Artifact name
+                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
+                metadata = (
+                    {
+                        k: v
+                        for k, v in dict(wandb.summary).items()
+                        if isinstance(v, numbers.Number) and not k.startswith("_")
+                    }
+                    if not args.load_best_model_at_end
+                    else {
+                        f"eval/{args.metric_for_best_model}": state.best_metric,
+                        "train/total_flos": state.total_flos,
+                    }
+                )
+                artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+                for f in Path(temp_dir).glob("*"):
+                    if f.is_file():
+                        with artifact.new_file(f.name, mode="wb") as fa:
+                            fa.write(f.read_bytes())
+                wandb.run.log_artifact(artifact)
+
     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
         if not self._initialized:
             self.setup(args, state, model, reinit=False)
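
Once a run has finished, the uploaded files can be pulled back through the wandb public API. A sketch under assumptions: the artifact path (entity/project/run name) and the Auto* classes are placeholders, and tokenizer files are only present if a tokenizer was passed to the Trainer so that save_model wrote them alongside the model.

# Sketch: retrieve the model files logged above in a later session (artifact path is a placeholder).
import wandb

api = wandb.Api()
artifact = api.artifact("my-entity/my-project/run-my-wandb-run:latest", type="model")
local_dir = artifact.download()  # downloads config, weights and (if saved) tokenizer files

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained(local_dir)
tokenizer = AutoTokenizer.from_pretrained(local_dir)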
@@ -261,7 +261,9 @@ class Trainer:
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
         callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
-        self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
+        self.callback_handler = CallbackHandler(
+            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
+        )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
 
         # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
@@ -168,6 +168,8 @@ class TrainerCallback:
             The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
         model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
             The model being trained.
+        tokenizer (:class:`~transformers.PreTrainedTokenizer`):
+            The tokenizer used for encoding the data.
         optimizer (:obj:`torch.optim.Optimizer`):
             The optimizer used for the training steps.
         lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
@@ -274,11 +276,12 @@ class TrainerCallback:
 class CallbackHandler(TrainerCallback):
     """ Internal class that just calls the list of callbacks in order. """
 
-    def __init__(self, callbacks, model, optimizer, lr_scheduler):
+    def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
         self.callbacks = []
         for cb in callbacks:
             self.add_callback(cb)
         self.model = model
+        self.tokenizer = tokenizer
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
         self.train_dataloader = None
@@ -376,6 +379,7 @@ class CallbackHandler(TrainerCallback):
             state,
             control,
             model=self.model,
+            tokenizer=self.tokenizer,
             optimizer=self.optimizer,
             lr_scheduler=self.lr_scheduler,
             train_dataloader=self.train_dataloader,
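
Because `call_event` now forwards the tokenizer on every event, any `TrainerCallback` can read it from its keyword arguments. A minimal sketch; the callback class below is hypothetical and not part of this commit.

# Sketch: a hypothetical callback that uses the tokenizer now forwarded by CallbackHandler.
from transformers import TrainerCallback

class VocabSizeLogger(TrainerCallback):
    def on_train_begin(self, args, state, control, tokenizer=None, **kwargs):
        if tokenizer is not None:
            print(f"Training starts with a vocabulary of {len(tokenizer)} tokens")

# Registered like any other callback:
# trainer = Trainer(..., callbacks=[VocabSizeLogger()])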