Unverified Commit bd0eab35 authored by Teven's avatar Teven Committed by GitHub
Browse files

Trainer + wandb quality of life logging tweaks (#6241)



* added `name` argument for wandb logging, also logging model config with trainer arguments

* Update src/transformers/training_args.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* added tf, post-review changes
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 33966811
...@@ -383,7 +383,10 @@ class Trainer: ...@@ -383,7 +383,10 @@ class Trainer:
logger.info( logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
) )
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=self.args.to_sanitized_dict()) combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
# keep track of model topology and gradients, unsupported on TPU # keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch( wandb.watch(
......
...@@ -215,7 +215,8 @@ class TFTrainer: ...@@ -215,7 +215,8 @@ class TFTrainer:
return self._setup_wandb() return self._setup_wandb()
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"') logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args)) combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name)
def prediction_loop( def prediction_loop(
self, self,
......
...@@ -109,6 +109,8 @@ class TrainingArguments: ...@@ -109,6 +109,8 @@ class TrainingArguments:
make use of the past hidden states for their predictions. If this argument is set to a positive int, the make use of the past hidden states for their predictions. If this argument is set to a positive int, the
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
at the next training step under the keyword argument ``mems``. at the next training step under the keyword argument ``mems``.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
""" """
output_dir: str = field( output_dir: str = field(
...@@ -222,6 +224,10 @@ class TrainingArguments: ...@@ -222,6 +224,10 @@ class TrainingArguments:
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."}, metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
) )
run_name: Optional[str] = field(
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
)
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
""" """
......
...@@ -95,6 +95,8 @@ class TFTrainingArguments(TrainingArguments): ...@@ -95,6 +95,8 @@ class TFTrainingArguments(TrainingArguments):
at the next training step under the keyword argument ``mems``. at the next training step under the keyword argument ``mems``.
tpu_name (:obj:`str`, `optional`): tpu_name (:obj:`str`, `optional`):
The name of the TPU the process is running on. The name of the TPU the process is running on.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
""" """
tpu_name: str = field( tpu_name: str = field(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment