Unverified Commit 6b660d5e authored by Bharat Ramanathan, committed by GitHub

Fix: handle logging of scalars in Weights & Biases summary (#29612)

fix: handle logging of scalars in wandb summary

fixes: #29430
parent 8e64ba28
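
Context for the fix: W&B distinguishes between `run.log()`, which appends to the step-indexed history and renders each key as a time-series chart, and `run.summary`, which holds a single value per key. Before this change, one-shot end-of-training scalars such as `train_loss` and `total_flos` went through `wandb.log`, so they showed up as one-point charts (#29430). A minimal sketch of the distinction, assuming a standalone run outside the `Trainer` (the project name and values are hypothetical, not part of this commit):

```python
import wandb

# Hypothetical standalone run; project name and values are illustrative.
run = wandb.init(project="summary-vs-log-demo")

# run.log() appends to the step-indexed history: each key becomes a chart.
for step in range(100):
    run.log({"train/loss": 1.0 / (step + 1)}, step=step)

# A one-off scalar (e.g. total runtime) belongs in the summary, where the UI
# shows it as a single value instead of a single-point chart.
run.summary["train_runtime"] = 123.4

run.finish()
```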
@@ -802,13 +802,25 @@ class WandbCallback(TrainerCallback):
             self._wandb.run.log_artifact(artifact)

     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+        single_value_scalars = [
+            "train_runtime",
+            "train_samples_per_second",
+            "train_steps_per_second",
+            "train_loss",
+            "total_flos",
+        ]
+
         if self._wandb is None:
             return
         if not self._initialized:
             self.setup(args, state, model)
         if state.is_world_process_zero:
-            logs = rewrite_logs(logs)
-            self._wandb.log({**logs, "train/global_step": state.global_step})
+            for k, v in logs.items():
+                if k in single_value_scalars:
+                    self._wandb.run.summary[k] = v
+            non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
+            non_scalar_logs = rewrite_logs(non_scalar_logs)
+            self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})

     def on_save(self, args, state, control, **kwargs):
         if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
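
To trace what the new branch does, here is the split applied to a hypothetical `logs` payload (keys and values are illustrative; `rewrite_logs` is the existing helper in this module, which namespaces the remaining keys, e.g. `loss` -> `train/loss`):

```python
# Hypothetical log payload from a Trainer run; values are illustrative.
logs = {"loss": 0.42, "learning_rate": 3e-5, "train_loss": 0.40, "total_flos": 1.2e15}

single_value_scalars = [
    "train_runtime",
    "train_samples_per_second",
    "train_steps_per_second",
    "train_loss",
    "total_flos",
]

# Written once per key to run.summary (shown as single values in the W&B UI):
summary_updates = {k: v for k, v in logs.items() if k in single_value_scalars}
# -> {'train_loss': 0.4, 'total_flos': 1200000000000000.0}

# Passed through rewrite_logs() and wandb.log() as step-indexed chart metrics:
chart_metrics = {k: v for k, v in logs.items() if k not in single_value_scalars}
# -> {'loss': 0.42, 'learning_rate': 3e-05}
```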