Commit 858b1d1e authored by jaymody's avatar jaymody Committed by Julien Chaumond
Browse files

allow an already created tensorboard SummaryWriter be passed to Trainer

parent 8e67573a
...@@ -123,6 +123,7 @@ class Trainer: ...@@ -123,6 +123,7 @@ class Trainer:
eval_dataset: Optional[Dataset] = None, eval_dataset: Optional[Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False, prediction_loss_only=False,
tb_writer: Optional["SummaryWriter"] = None,
): ):
""" """
Trainer is a simple but feature-complete training and eval loop for PyTorch, Trainer is a simple but feature-complete training and eval loop for PyTorch,
...@@ -142,7 +143,9 @@ class Trainer: ...@@ -142,7 +143,9 @@ class Trainer:
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only self.prediction_loss_only = prediction_loss_only
if is_tensorboard_available() and self.args.local_rank in [-1, 0]: if tb_writer is not None:
self.tb_writer = tb_writer
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir) self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available(): if not is_tensorboard_available():
logger.warning( logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment