Unverified Commit 24175910 authored by Julien Chaumond, committed by GitHub

(v2) Improvements to the wandb integration (#4324)



* Improvements to the wandb integration

* small reorg + no global necessary

* feat(trainer): log epoch and final metrics

* Simplify logging a bit

* Fixup

* Fix crash when just running eval
Co-authored-by: Chris Van Pelt <vanpelt@gmail.com>
Co-authored-by: Boris Dayma <boris.dayma@gmail.com>
parent 7d7fe499
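
For orientation before the hunks: the reworked integration is driven entirely by environment variables, and wandb.init now runs from Trainer.__init__ rather than from train(). A minimal usage sketch, assuming model, training_args and train_dataset are already built (those objects are placeholders, not part of this commit):

    import os

    # Store runs under a custom W&B project instead of the default "huggingface"
    os.environ["WANDB_PROJECT"] = "my-experiments"  # hypothetical project name
    # Log gradients and parameters ("gradients" is the default; "false" skips wandb.watch)
    os.environ["WANDB_WATCH"] = "all"
    # Or opt out of the integration entirely:
    # os.environ["WANDB_DISABLED"] = "true"

    from transformers import Trainer

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    trainer.train()  # step, epoch and eval metrics are now logged via the new _log helper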
@@ -265,7 +265,7 @@ def main():
        eval_output = trainer.evaluate()

-        perplexity = math.exp(eval_output["loss"])
+        perplexity = math.exp(eval_output["eval_loss"])
        result = {"perplexity": perplexity}

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
...
@@ -72,7 +72,7 @@ class ExamplesTests(unittest.TestCase):
            """.split()
        with patch.object(sys, "argv", testargs):
            result = run_glue.main()
-            del result["loss"]
+            del result["eval_loss"]
            for value in result.values():
                self.assertGreaterEqual(value, 0.75)
...
@@ -6,7 +6,7 @@ from unittest.mock import patch

 import run_ner

-logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()

@@ -30,4 +30,4 @@ class ExamplesTests(unittest.TestCase):
            """.split()
        with patch.object(sys, "argv", ["run.py"] + testargs):
            result = run_ner.main()
-            self.assertLess(result["loss"], 1.5)
+            self.assertLess(result["eval_loss"], 1.5)
@@ -61,7 +61,12 @@ def is_tensorboard_available():

 try:
     import wandb

-    _has_wandb = True
+    wandb.ensure_configured()
+    if wandb.api.api_key is None:
+        _has_wandb = False
+        wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
+    else:
+        _has_wandb = False if os.getenv("WANDB_DISABLED") else True
 except ImportError:
     _has_wandb = False
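
One subtlety worth flagging in the hunk above: the check `False if os.getenv("WANDB_DISABLED") else True` is truthiness-based, so any non-empty value of WANDB_DISABLED, including the string "false", disables the integration. A standalone, runnable sketch of the same guarded-import pattern:

    import os

    try:
        import wandb

        # ensure_configured() loads credentials without starting a run
        wandb.ensure_configured()
        if wandb.api.api_key is None:
            _has_wandb = False  # wandb installed but not logged in
        else:
            # truthiness check: any non-empty WANDB_DISABLED opts out
            _has_wandb = False if os.getenv("WANDB_DISABLED") else True
    except ImportError:
        _has_wandb = False

    def is_wandb_available():
        return _has_wandb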
@@ -114,6 +119,8 @@ class Trainer:
     prediction_loss_only: bool
     tb_writer: Optional["SummaryWriter"] = None
     optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
+    global_step: Optional[int] = None
+    epoch: Optional[float] = None

     def __init__(
         self,
@@ -154,9 +161,12 @@ class Trainer:
             logger.warning(
                 "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
             )
-        if not is_wandb_available():
+        if is_wandb_available():
+            self._setup_wandb()
+        else:
             logger.info(
-                "You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
+                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
+                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
             )
         set_seed(self.args.seed)
         # Create output directory if needed
@@ -263,11 +273,25 @@ class Trainer:
         """
         Setup the optional Weights & Biases (`wandb`) integration.

-        One can override this method to customize the setup if needed.
+        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
+        You can also override the following environment variables:
+
+        Environment:
+            WANDB_WATCH:
+                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
+                or "all" to log gradients and parameters
+            WANDB_PROJECT:
+                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
+            WANDB_DISABLED:
+                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
         """
-        wandb.init(name=self.args.logging_dir, config=vars(self.args))
+        logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
+        wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
         # keep track of model topology and gradients
-        wandb.watch(self.model)
+        if os.getenv("WANDB_WATCH") != "false":
+            wandb.watch(
+                self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
+            )

     def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
         """
@@ -333,8 +357,6 @@ class Trainer:
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
             self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
-        if is_wandb_available():
-            self._setup_wandb()

         # Train!
         if is_tpu_available():
@@ -353,25 +375,26 @@ class Trainer:
         logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
         logger.info("  Total optimization steps = %d", t_total)

-        global_step = 0
+        self.global_step = 0
+        self.epoch = 0
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
         # Check if continuing training from a checkpoint
         if model_path is not None:
             # set global_step to global_step of last saved checkpoint from model path
             try:
-                global_step = int(model_path.split("-")[-1].split("/")[0])
-                epochs_trained = global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
-                steps_trained_in_current_epoch = global_step % (
+                self.global_step = int(model_path.split("-")[-1].split("/")[0])
+                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
+                steps_trained_in_current_epoch = self.global_step % (
                     len(train_dataloader) // self.args.gradient_accumulation_steps
                 )
                 logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                 logger.info("  Continuing training from epoch %d", epochs_trained)
-                logger.info("  Continuing training from global step %d", global_step)
+                logger.info("  Continuing training from global step %d", self.global_step)
                 logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
             except ValueError:
-                global_step = 0
+                self.global_step = 0
                 logger.info("  Starting fine-tuning.")

         tr_loss = 0.0
@@ -408,34 +431,24 @@ class Trainer:
                     scheduler.step()
                     model.zero_grad()
-                    global_step += 1
+                    self.global_step += 1
+                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                     if self.is_local_master():
-                        if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (
-                            global_step == 1 and self.args.logging_first_step
+                        if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
+                            self.global_step == 1 and self.args.logging_first_step
                         ):
-                            logs = {}
-                            if self.args.evaluate_during_training:
-                                results = self.evaluate()
-                                for key, value in results.items():
-                                    eval_key = "eval_{}".format(key)
-                                    logs[eval_key] = value
-
-                            loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps
-                            learning_rate_scalar = scheduler.get_last_lr()[0]
-                            logs["learning_rate"] = learning_rate_scalar
-                            logs["loss"] = loss_scalar
+                            logs: Dict[str, float] = {}
+                            logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
+                            logs["learning_rate"] = scheduler.get_last_lr()[0]
                             logging_loss = tr_loss

-                            if self.tb_writer:
-                                for k, v in logs.items():
-                                    self.tb_writer.add_scalar(k, v, global_step)
-                            if is_wandb_available():
-                                wandb.log(logs, step=global_step)
-
-                            epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
+                            self._log(logs)
+
+                            if self.args.evaluate_during_training:
+                                self.evaluate()

-                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
+                        if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                             # In all cases (even distributed/parallel), self.model is always a reference
                             # to the model we want to save.
                             if hasattr(model, "module"):
@@ -443,7 +456,9 @@ class Trainer:
                             else:
                                 assert model is self.model
                             # Save model checkpoint
-                            output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
+                            output_dir = os.path.join(
+                                self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
+                            )

                             self.save_model(output_dir)
                             self._rotate_checkpoints()
@@ -451,10 +466,10 @@ class Trainer:
                                 torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                                 logger.info("Saving optimizer and scheduler states to %s", output_dir)

-                if self.args.max_steps > 0 and global_step > self.args.max_steps:
+                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                     epoch_iterator.close()
                     break
-            if self.args.max_steps > 0 and global_step > self.args.max_steps:
+            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                 train_iterator.close()
                 break
             if self.args.tpu_metrics_debug:
@@ -465,7 +480,21 @@ class Trainer:
             self.tb_writer.close()

         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
-        return TrainOutput(global_step, tr_loss / global_step)
+        return TrainOutput(self.global_step, tr_loss / self.global_step)
+
+    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
+        if self.epoch is not None:
+            logs["epoch"] = self.epoch
+        if self.tb_writer:
+            for k, v in logs.items():
+                self.tb_writer.add_scalar(k, v, self.global_step)
+        if is_wandb_available():
+            wandb.log(logs, step=self.global_step)
+        output = json.dumps({**logs, **{"step": self.global_step}})
+        if iterator is not None:
+            iterator.write(output)
+        else:
+            print(output)

     def _training_step(
         self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
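
To make the new funnel concrete, a sketch of what one _log call fans out to, assuming tb_writer is set, wandb is logged in, self.global_step == 500 and self.epoch == 1.25:

    # inside a Trainer method, under the assumptions above
    self._log({"loss": 0.431, "learning_rate": 4.2e-05})
    # tb_writer.add_scalar(...) runs for "loss", "learning_rate" and the injected "epoch", all at step 500
    # wandb.log({"loss": 0.431, "learning_rate": 4.2e-05, "epoch": 1.25}, step=500)
    # no iterator was passed, so it prints: {"loss": 0.431, "learning_rate": 4.2e-05, "epoch": 1.25, "step": 500}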
@@ -582,6 +611,8 @@ class Trainer:

         output = self._prediction_loop(eval_dataloader, description="Evaluation")

+        self._log(output.metrics)
+
         if self.args.tpu_metrics_debug:
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())
@@ -663,6 +694,11 @@ class Trainer:
         else:
             metrics = {}
         if len(eval_losses) > 0:
-            metrics["loss"] = np.mean(eval_losses)
+            metrics["eval_loss"] = np.mean(eval_losses)
+
+        # Prefix all keys with eval_
+        for key in list(metrics.keys()):
+            if not key.startswith("eval_"):
+                metrics[f"eval_{key}"] = metrics.pop(key)

         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
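
The prefixing loop above is a plain dict rewrite; the list(metrics.keys()) copy matters because the dict is mutated during iteration. A quick worked example:

    metrics = {"eval_loss": 0.12, "acc": 0.91}
    for key in list(metrics.keys()):
        if not key.startswith("eval_"):
            metrics[f"eval_{key}"] = metrics.pop(key)
    assert metrics == {"eval_loss": 0.12, "eval_acc": 0.91}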
@@ -98,7 +98,7 @@ class TrainerIntegrationTest(unittest.TestCase):
         training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
         trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
         result = trainer.evaluate()
-        self.assertLess(result["loss"], 0.2)
+        self.assertLess(result["eval_loss"], 0.2)

     def test_trainer_eval_lm(self):
         MODEL_ID = "distilroberta-base"
...