"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2d506ea4c4980a4cab43c2940d9836ddfd629524"
Unverified Commit 07cde58b authored by Bharat Ramanathan, committed by GitHub

feature: update wandb callback to upload checkpoints (#21035)



* docs: add wandb metrics and model checkpointing to callback docstrings

* docs: update reference to wandb documentation

* fix: change default of `WANDB_WATCH` from `"gradients"` to `"false"`

* feature: add `on_save` method and update `WANDB_LOG_MODEL` behaviour

* fix: use default wandb run names instead of `output_dir`

- removes duplicated run names from wandb workspace
- models can be logged with corresponding run names

* fix: edit deprecation warning based on review suggestions
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix: change indentation of docstrings

* fix: change indentation of docstrings and run fixup

* fix: empty commit for circleci permissions issue

* fix: format deprecation doc strings review suggestion
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* docs: Highlight WANDB_DISABLED arg in documentation
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix: run fixup after updating docstrings
Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent a3c37825
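Before the diff itself, a minimal sketch (not part of the commit) of how the updated environment variables described above would be set from user code. The project name and the chosen values are examples only; the accepted values are taken from the docstring changes below.

```python
import os

# New string-valued contract for WANDB_LOG_MODEL introduced by this change:
#   "end"        -> upload the final model as a W&B artifact when training ends
#   "checkpoint" -> also upload a checkpoint artifact every `save_steps` (new `on_save` hook)
#   "false"      -> upload nothing (the default)
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

# WANDB_WATCH now defaults to "false"; opt back in explicitly to log gradients/parameters.
os.environ["WANDB_WATCH"] = "gradients"

# Example project name; without it, runs land in the default "huggingface" project.
os.environ["WANDB_PROJECT"] = "my-project"
```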
@@ -644,7 +644,7 @@ class TensorBoardCallback(TrainerCallback):
 class WandbCallback(TrainerCallback):
     """
-    A [`TrainerCallback`] that sends the logs to [Weight and Biases](https://www.wandb.com/).
+    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
     """

     def __init__(self):
@@ -656,28 +656,44 @@ class WandbCallback(TrainerCallback):
         self._wandb = wandb
         self._initialized = False
-        # log outputs
-        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
+        # log model
+        if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
+            DeprecationWarning(
+                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
+                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
+            )
+            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
+            self._log_model = "end"
+        else:
+            self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()

     def setup(self, args, state, model, **kwargs):
         """
         Setup the optional Weights & Biases (*wandb*) integration.

         One can subclass and override this method to customize the setup if needed. Find more information
-        [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
+        [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
         variables:

         Environment:
-        - **WANDB_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
-            Whether or not to log model as artifact at the end of training. Use along with
-            [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.
-        - **WANDB_WATCH** (`str`, *optional*, defaults to `gradients`):
-            Can be `gradients`, `all` or `false`. Set to `false` to disable gradient logging or `all` to log gradients
-            and parameters.
-        - **WANDB_PROJECT** (`str`, *optional*, defaults to `huggingface`):
+        - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
+            Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
+            to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
+            will be uploaded every `args.save_steps`. If set to `"false"`, the model will not be uploaded. Use along
+            with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.
+
+            <Deprecated version="5.0">
+
+            Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers.
+
+            </Deprecated>
+        - **WANDB_WATCH** (`str`, *optional*, defaults to `"false"`):
+            Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
+            parameters.
+        - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
             Set this to a custom string to store results in a different project.
         - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`):
-            Whether or not to disable wandb entirely. Set `WANDB_DISABLED=True` to disable.
+            Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.
         """
         if self._wandb is None:
             return
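For readers skimming the hunk above: a small standalone sketch of the deprecation fallback in `__init__`. Legacy boolean-style values of `WANDB_LOG_MODEL` are still accepted but mapped to `"end"`. The helper name and the exact contents of the truthy-value set are assumptions for illustration, not part of the diff.

```python
# Assumed to mirror transformers' ENV_VARS_TRUE_VALUES plus the explicit "TRUE" union in the diff.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


def resolve_log_model(raw: str) -> str:
    """Illustrative mirror of the new WANDB_LOG_MODEL handling in WandbCallback.__init__."""
    if raw.upper() in ENV_VARS_TRUE_VALUES:
        # Deprecated boolean usage: treated as "end" (upload the final model only).
        return "end"
    return raw.lower()  # expected values: "end", "checkpoint" or "false"


assert resolve_log_model("TRUE") == "end"          # legacy True -> "end"
assert resolve_log_model("checkpoint") == "checkpoint"
assert resolve_log_model("False") == "false"
```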
@@ -694,15 +710,16 @@ class WandbCallback(TrainerCallback):
             trial_name = state.trial_name
             init_args = {}
             if trial_name is not None:
-                run_name = trial_name
+                init_args["name"] = trial_name
                 init_args["group"] = args.run_name
             else:
-                run_name = args.run_name
+                if not (args.run_name is None or args.run_name == args.output_dir):
+                    init_args["name"] = args.run_name

             if self._wandb.run is None:
                 self._wandb.init(
                     project=os.getenv("WANDB_PROJECT", "huggingface"),
-                    name=run_name,
                     **init_args,
                 )
             # add config parameters (run may have been created manually)
@@ -714,10 +731,9 @@ class WandbCallback(TrainerCallback):
                 self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

             # keep track of model topology and gradients, unsupported on TPU
-            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
-                self._wandb.watch(
-                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
-                )
+            _watch_model = os.getenv("WANDB_WATCH", "false")
+            if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
+                self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))

     def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
@@ -733,7 +749,7 @@ class WandbCallback(TrainerCallback):
     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._wandb is None:
             return
-        if self._log_model and self._initialized and state.is_world_process_zero:
+        if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
             from .trainer import Trainer

             fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -751,7 +767,13 @@ class WandbCallback(TrainerCallback):
                         "train/total_floss": state.total_flos,
                     }
                 )
-                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
+                logger.info("Logging model artifacts. ...")
+                model_name = (
+                    f"model-{self._wandb.run.id}"
+                    if (args.run_name is None or args.run_name == args.output_dir)
+                    else f"model-{self._wandb.run.name}"
+                )
+                artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
                 for f in Path(temp_dir).glob("*"):
                     if f.is_file():
                         with artifact.new_file(f.name, mode="wb") as fa:
@@ -767,6 +789,26 @@ class WandbCallback(TrainerCallback):
             logs = rewrite_logs(logs)
             self._wandb.log({**logs, "train/global_step": state.global_step})

+    def on_save(self, args, state, control, **kwargs):
+        if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
+            checkpoint_metadata = {
+                k: v
+                for k, v in dict(self._wandb.summary).items()
+                if isinstance(v, numbers.Number) and not k.startswith("_")
+            }
+
+            ckpt_dir = f"checkpoint-{state.global_step}"
+            artifact_path = os.path.join(args.output_dir, ckpt_dir)
+            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
+            checkpoint_name = (
+                f"checkpoint-{self._wandb.run.id}"
+                if (args.run_name is None or args.run_name == args.output_dir)
+                else f"checkpoint-{self._wandb.run.name}"
+            )
+            artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
+            artifact.add_dir(artifact_path)
+            self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
+

 class CometCallback(TrainerCallback):
     """
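With `WANDB_LOG_MODEL="checkpoint"`, each save in the new `on_save` hook logs an artifact named `checkpoint-<run_id>` (or `checkpoint-<run_name>` when a custom `run_name` is set) with an alias `checkpoint-<global_step>`. A hedged sketch of how such an artifact could be fetched later with the public wandb API; the entity, project, run id and step below are placeholders, not values from this commit.

```python
import wandb

# Placeholders: substitute your own entity/project and the run id / step shown in the W&B UI.
ENTITY, PROJECT = "my-entity", "my-project"
RUN_ID, STEP = "abc123", 500

api = wandb.Api()
# Artifact name follows the naming scheme added in this diff; the alias pins a specific save step.
artifact = api.artifact(f"{ENTITY}/{PROJECT}/checkpoint-{RUN_ID}:checkpoint-{STEP}")
local_dir = artifact.download()  # downloads the checkpoint directory locally
print("checkpoint restored to", local_dir)
```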