"...composable_kernel_rocm.git" did not exist on "951a52b2050bfabc2773160aa427be6644212ad6"
Unverified commit 07cde58b authored by Bharat Ramanathan, committed by GitHub

feature: update wandb callback to upload checkpoints (#21035)



* docs: add wandb metrics and model checkpointing to callback docstrings

* docs: update reference to wandb documentation

* fix: change default of `WANDB_WATCH` from `"gradients"` to `"false"`

* feature: add `on_save` method and update `WANDB_LOG_MODEL` behaviour (see the usage sketch below)

* fix: use default wandb run names instead of `output_dir`

- removes duplicated run names from wandb workspace
- models can be logged with corresponding run names

* fix: edit deprecation warning based on review suggestions
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix: change indentation of docstrings

* fix: change indentation of docstrings and run fixup

* fix: empty commit for circleci permissions issue

* fix: format deprecation doc strings review suggestion
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* docs: highlight WANDB_DISABLED arg in documentation
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix: run fixup after updating docstrings
Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent a3c37825
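
In practice, the behaviour added by this commit is driven entirely by environment variables read by `WandbCallback`. Below is a minimal usage sketch; the model, dataset, and run names are illustrative placeholders, not part of this commit:

```python
import os

# "checkpoint" uploads an artifact every `save_steps`; "end" uploads only the
# final model; "false" (the new-style default) disables model upload.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
# WANDB_WATCH now defaults to "false"; opt back in to gradient logging explicitly.
os.environ["WANDB_WATCH"] = "gradients"

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    report_to=["wandb"],       # enables WandbCallback
    run_name="my-experiment",  # used as the wandb run name and artifact name suffix
    save_steps=500,            # a checkpoint artifact is logged every 500 steps
)
# `model` and `train_dataset` are assumed to be defined as in any standard
# fine-tuning setup.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```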
@@ -644,7 +644,7 @@ class TensorBoardCallback(TrainerCallback):
 class WandbCallback(TrainerCallback):
     """
-    A [`TrainerCallback`] that sends the logs to [Weight and Biases](https://www.wandb.com/).
+    A [`TrainerCallback`] that logs metrics, media and model checkpoints to [Weights & Biases](https://www.wandb.com/).
     """

     def __init__(self):
@@ -656,28 +656,44 @@ class WandbCallback(TrainerCallback):
         self._wandb = wandb
         self._initialized = False
-        # log outputs
-        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
+        # log model
+        if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
+            # emit the deprecation notice via the module logger (a bare
+            # `DeprecationWarning(...)` would be constructed but never raised)
+            logger.warning(
+                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
+                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
+            )
+            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
+            self._log_model = "end"
+        else:
+            self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()

     def setup(self, args, state, model, **kwargs):
         """
         Setup the optional Weights & Biases (*wandb*) integration.

         One can subclass and override this method to customize the setup if needed. Find more information
-        [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
+        [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
         variables:

         Environment:
-        - **WANDB_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
-            Whether or not to log model as artifact at the end of training. Use along with
-            [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.
-        - **WANDB_WATCH** (`str`, *optional*, defaults to `gradients`):
-            Can be `gradients`, `all` or `false`. Set to `false` to disable gradient logging or `all` to log gradients
-            and parameters.
-        - **WANDB_PROJECT** (`str`, *optional*, defaults to `huggingface`):
+        - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
+            Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
+            to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
+            will be uploaded every `args.save_steps`. If set to `"false"`, the model will not be uploaded. Use along
+            with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload the best model.
+
+            <Deprecated version="5.0">
+
+            Setting `WANDB_LOG_MODEL` as a boolean is deprecated and will be removed in version 5 of 🤗 Transformers.
+
+            </Deprecated>
+        - **WANDB_WATCH** (`str`, *optional*, defaults to `"false"`):
+            Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
+            parameters.
+        - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
             Set this to a custom string to store results in a different project.
         - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`):
-            Whether or not to disable wandb entirely. Set `WANDB_DISABLED=True` to disable.
+            Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.
         """
         if self._wandb is None:
             return
@@ -694,15 +710,16 @@ class WandbCallback(TrainerCallback):
             trial_name = state.trial_name
             init_args = {}
             if trial_name is not None:
-                run_name = trial_name
+                init_args["name"] = trial_name
+                init_args["group"] = args.run_name
             else:
-                run_name = args.run_name
+                if not (args.run_name is None or args.run_name == args.output_dir):
+                    init_args["name"] = args.run_name

             if self._wandb.run is None:
                 self._wandb.init(
                     project=os.getenv("WANDB_PROJECT", "huggingface"),
-                    name=run_name,
+                    **init_args,
                 )
             # add config parameters (run may have been created manually)
@@ -714,10 +731,9 @@ class WandbCallback(TrainerCallback):
                     self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

             # keep track of model topology and gradients, unsupported on TPU
-            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
-                self._wandb.watch(
-                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
-                )
+            _watch_model = os.getenv("WANDB_WATCH", "false")
+            if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
+                self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))

     def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
@@ -733,7 +749,7 @@ class WandbCallback(TrainerCallback):
     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._wandb is None:
             return
-        if self._log_model and self._initialized and state.is_world_process_zero:
+        if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
             from .trainer import Trainer

             fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -751,7 +767,13 @@ class WandbCallback(TrainerCallback):
"train/total_floss": state.total_flos,
}
)
artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
logger.info("Logging model artifacts. ...")
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
@@ -767,6 +789,26 @@ class WandbCallback(TrainerCallback):
             logs = rewrite_logs(logs)
             self._wandb.log({**logs, "train/global_step": state.global_step})

+    def on_save(self, args, state, control, **kwargs):
+        if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
+            checkpoint_metadata = {
+                k: v
+                for k, v in dict(self._wandb.summary).items()
+                if isinstance(v, numbers.Number) and not k.startswith("_")
+            }
+
+            ckpt_dir = f"checkpoint-{state.global_step}"
+            artifact_path = os.path.join(args.output_dir, ckpt_dir)
+            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
+            checkpoint_name = (
+                f"checkpoint-{self._wandb.run.id}"
+                if (args.run_name is None or args.run_name == args.output_dir)
+                else f"checkpoint-{self._wandb.run.name}"
+            )
+            artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
+            artifact.add_dir(artifact_path)
+            self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
+

 class CometCallback(TrainerCallback):
     """
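
For reference, a checkpoint logged by the new `on_save` hook can later be pulled back with wandb's artifact API. A minimal sketch, assuming the run was named `my-experiment` and the project was left at the `huggingface` default (both placeholders):

```python
import wandb
from transformers import AutoModelForSequenceClassification

run = wandb.init(project="huggingface")
# The callback names checkpoint artifacts `checkpoint-<run name or run id>` and
# aliases each logged version `checkpoint-<global_step>`.
artifact = run.use_artifact("checkpoint-my-experiment:latest", type="model")
checkpoint_dir = artifact.download()

# Any `from_pretrained`-compatible class works here; this one is illustrative.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
```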