"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bf9056442ac58218da7623da2a0f7f4cd02689ad"
Unverified Commit 6b660d5e authored by Bharat Ramanathan, committed by GitHub
Browse files

Fix: handle logging of scalars in Weights & Biases summary (#29612)

fix: handle logging of scalars in wandb summary

fixes: #29430
parent 8e64ba28
@@ -802,13 +802,25 @@ class WandbCallback(TrainerCallback):
             self._wandb.run.log_artifact(artifact)
 
     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+        single_value_scalars = [
+            "train_runtime",
+            "train_samples_per_second",
+            "train_steps_per_second",
+            "train_loss",
+            "total_flos",
+        ]
         if self._wandb is None:
             return
         if not self._initialized:
             self.setup(args, state, model)
         if state.is_world_process_zero:
-            logs = rewrite_logs(logs)
-            self._wandb.log({**logs, "train/global_step": state.global_step})
+            for k, v in logs.items():
+                if k in single_value_scalars:
+                    self._wandb.run.summary[k] = v
+            non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
+            non_scalar_logs = rewrite_logs(non_scalar_logs)
+            self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
 
     def on_save(self, args, state, control, **kwargs):
         if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment