Unverified Commit 3ae2e86b authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Run a single wandb instance per TPU run (#4851)



* Run a single wandb instance per TPU run

* wandb: self.is_world_master

* make style
Co-authored-by: default avatarJulien Chaumond <chaumond@gmail.com>
parent 466aa57a
...@@ -336,13 +336,16 @@ class Trainer: ...@@ -336,13 +336,16 @@ class Trainer:
WANDB_DISABLED: WANDB_DISABLED:
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
""" """
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"') if self.is_world_master():
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args)) logger.info(
# keep track of model topology and gradients 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
if os.getenv("WANDB_WATCH") != "false":
wandb.watch(
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
) )
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
# keep track of model topology and gradients
if os.getenv("WANDB_WATCH") != "false":
wandb.watch(
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
)
def num_examples(self, dataloader: DataLoader) -> int: def num_examples(self, dataloader: DataLoader) -> int:
""" """
...@@ -557,7 +560,8 @@ class Trainer: ...@@ -557,7 +560,8 @@ class Trainer:
self.tb_writer.add_scalar(k, v, self.global_step) self.tb_writer.add_scalar(k, v, self.global_step)
self.tb_writer.flush() self.tb_writer.flush()
if is_wandb_available(): if is_wandb_available():
wandb.log(logs, step=self.global_step) if self.is_world_master():
wandb.log(logs, step=self.global_step)
output = json.dumps({**logs, **{"step": self.global_step}}) output = json.dumps({**logs, **{"step": self.global_step}})
if iterator is not None: if iterator is not None:
iterator.write(output) iterator.write(output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment