Unverified Commit 0f360d3d authored by krfricke's avatar krfricke Committed by GitHub
Browse files

move wandb/comet logger init to train() to allow parallel logging (#6850)

* move wandb/comet logger init to train() to allow parallel logging

* Setup wandb/comet loggers on first call to log()
parent 39ed68d5
...@@ -255,20 +255,10 @@ class Trainer: ...@@ -255,20 +255,10 @@ class Trainer:
logger.warning( logger.warning(
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it." "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
) )
if is_wandb_available():
self.setup_wandb() # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
elif os.environ.get("WANDB_DISABLED") != "true": self._loggers_initialized = False
logger.info(
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
)
if is_comet_available():
self.setup_comet()
elif os.environ.get("COMET_MODE") != "DISABLED":
logger.info(
"To use comet_ml logging, run `pip/conda install comet_ml` "
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
# Create output directory if needed # Create output directory if needed
if self.is_world_process_zero(): if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True) os.makedirs(self.args.output_dir, exist_ok=True)
...@@ -518,6 +508,25 @@ class Trainer: ...@@ -518,6 +508,25 @@ class Trainer:
""" """
return len(dataloader.dataset) return len(dataloader.dataset)
def _setup_loggers(self):
if self._loggers_initialized:
return
if is_wandb_available():
self.setup_wandb()
elif os.environ.get("WANDB_DISABLED") != "true":
logger.info(
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
)
if is_comet_available():
self.setup_comet()
elif os.environ.get("COMET_MODE") != "DISABLED":
logger.info(
"To use comet_ml logging, run `pip/conda install comet_ml` "
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
self._loggers_initialized = True
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """ """ HP search setup code """
if self.hp_search_backend is None or trial is None: if self.hp_search_backend is None or trial is None:
...@@ -903,6 +912,9 @@ class Trainer: ...@@ -903,6 +912,9 @@ class Trainer:
iterator (:obj:`tqdm`, `optional`): iterator (:obj:`tqdm`, `optional`):
A potential tqdm progress bar to write the logs on. A potential tqdm progress bar to write the logs on.
""" """
# Set up loggers like W&B or Comet ML
self._setup_loggers()
if hasattr(self, "_log"): if hasattr(self, "_log"):
warnings.warn( warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.", "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment