Unverified commit 818463ee authored by Boris Dayma, committed by GitHub

Trainer: add logging through Weights & Biases (#3916)



* feat: add logging through Weights & Biases

* feat(wandb): make logging compatible with all scripts

* style(trainer.py): fix formatting

* [Trainer] Tweak wandb integration
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
parent 858b1d1e
@@ -131,6 +131,7 @@ proc_data
 # examples
 runs
 /runs_old
+/wandb
 examples/runs
 # data
...
@@ -52,6 +52,18 @@ def is_tensorboard_available():
     return _has_tensorboard


+try:
+    import wandb
+
+    _has_wandb = True
+except ImportError:
+    _has_wandb = False
+
+
+def is_wandb_available():
+    return _has_wandb
+
+
 logger = logging.getLogger(__name__)
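
The try/except above follows the same optional-dependency pattern the file already uses for TensorBoard (see the `is_tensorboard_available` context in the hunk header): the import is attempted once at module load, and `is_wandb_available()` exposes the result. A minimal, self-contained sketch of the pattern from the caller's side, re-declaring the guard locally since the helper's module path is not shown in this diff; the project name is a hypothetical placeholder:

    # Optional-dependency guard, mirroring the hunk above.
    try:
        import wandb

        _has_wandb = True
    except ImportError:
        _has_wandb = False


    def is_wandb_available():
        return _has_wandb


    if is_wandb_available():
        wandb.init(project="my-project")  # hypothetical project name
    else:
        print("wandb not installed; skipping Weights & Biases logging")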
@@ -151,6 +163,10 @@ class Trainer:
             logger.warning(
                 "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
             )
+        if not is_wandb_available():
+            logger.info(
+                "You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
+            )
         set_seed(self.args.seed)
         # Create output directory if needed
         if self.args.local_rank in [-1, 0]:
@@ -209,6 +225,12 @@ class Trainer:
         )
         return optimizer, scheduler

+    def _setup_wandb(self):
+        # Start a wandb run and log config parameters
+        wandb.init(name=self.args.logging_dir, config=vars(self.args))
+        # keep track of model topology and gradients
+        # wandb.watch(self.model)
+
     def train(self, model_path: Optional[str] = None):
         """
         Main training entry point.
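
For context, `wandb.init` starts a run and records a config dict, and `vars(self.args)` turns the `TrainingArguments` dataclass instance into plain key/value pairs. A standalone sketch of the same calls under those assumptions; the run name and hyperparameter values are illustrative stand-ins for `self.args` fields:

    import wandb

    # Name the run after the logging directory and record hyperparameters.
    wandb.init(
        name="runs/Apr21_10-00-00",  # illustrative logging_dir value
        config={"learning_rate": 5e-5, "num_train_epochs": 3.0},
    )

    # wandb.watch(model) would additionally log parameter and gradient
    # histograms; it is left commented out in the diff above.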
@@ -263,6 +285,9 @@ class Trainer:
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
+            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
+        if is_wandb_available():
+            self._setup_wandb()

         # Train!
         logger.info("***** Running training *****")
@@ -351,6 +376,9 @@ class Trainer:
                     if self.tb_writer:
                         for k, v in logs.items():
                             self.tb_writer.add_scalar(k, v, global_step)
+                    if is_wandb_available():
+                        wandb.log(logs, step=global_step)
+
                     epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

                 if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
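
Note the difference in logging granularity: TensorBoard needs one `add_scalar` call per metric, while `wandb.log` takes the whole `logs` dict in a single call keyed to the same `global_step`. A self-contained sketch of both; the metric values are illustrative, and `mode="disabled"` (available in recent wandb releases) makes the run a no-op so the snippet runs offline:

    import wandb
    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter()
    global_step = 100
    logs = {"loss": 0.41, "learning_rate": 4.8e-5}  # illustrative values

    for k, v in logs.items():          # TensorBoard: one scalar per call
        tb_writer.add_scalar(k, v, global_step)

    wandb.init(mode="disabled")        # no-op run for this offline sketch
    wandb.log(logs, step=global_step)  # wandb: the whole dict at once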
@@ -467,7 +495,7 @@ class Trainer:
                 shutil.rmtree(checkpoint)

     def evaluate(
-        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None
+        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
     ) -> Dict[str, float]:
         """
         Run evaluation and return metrics.
...
@@ -2,7 +2,7 @@ import dataclasses
 import json
 import logging
 from dataclasses import dataclass, field
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 from .file_utils import cached_property, is_torch_available, torch_required
@@ -138,3 +138,13 @@ class TrainingArguments:
         Serializes this instance to a JSON string.
         """
         return json.dumps(dataclasses.asdict(self), indent=2)
+
+    def to_sanitized_dict(self) -> Dict[str, Any]:
+        """
+        Sanitized serialization to use with TensorBoard’s hparams
+        """
+        d = dataclasses.asdict(self)
+        valid_types = [bool, int, float, str]
+        if is_torch_available():
+            valid_types.append(torch.Tensor)
+        return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
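
To see what the sanitization does, here is a runnable toy version with a hypothetical `ToyArgs` dataclass standing in for `TrainingArguments` (the torch branch is omitted to keep it stdlib-only):

    import dataclasses
    from typing import Any, Dict, Optional

    @dataclasses.dataclass
    class ToyArgs:  # hypothetical stand-in for TrainingArguments
        learning_rate: float = 5e-5
        output_dir: str = "out"
        max_steps: Optional[int] = None  # None is not a valid hparam type

    def to_sanitized_dict(args) -> Dict[str, Any]:
        d = dataclasses.asdict(args)
        valid_types = [bool, int, float, str]
        # Values of unsupported types are stringified so TensorBoard accepts them.
        return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}

    print(to_sanitized_dict(ToyArgs()))
    # -> {'learning_rate': 5e-05, 'output_dir': 'out', 'max_steps': 'None'}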