Unverified Commit a97a73e0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Small QOL improvements to TrainingArguments (#7475)

* Small QOL improvements to TrainingArguments

* With the self.
parent dc7d2daa
...@@ -49,8 +49,9 @@ class TrainingArguments: ...@@ -49,8 +49,9 @@ class TrainingArguments:
:obj:`output_dir` points to a checkpoint directory. :obj:`output_dir` points to a checkpoint directory.
do_train (:obj:`bool`, `optional`, defaults to :obj:`False`): do_train (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run training or not. Whether to run training or not.
do_eval (:obj:`bool`, `optional`, defaults to :obj:`False`): do_eval (:obj:`bool`, `optional`):
Whether to run evaluation on the dev set or not. Whether to run evaluation on the dev set or not. Will default to :obj:`evaluation_strategy` different from
:obj:`"no"`.
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`): do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run predictions on the test set or not. Whether to run predictions on the test set or not.
evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
...@@ -183,7 +184,7 @@ class TrainingArguments: ...@@ -183,7 +184,7 @@ class TrainingArguments:
) )
do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluate_during_training: bool = field( evaluate_during_training: bool = field(
default=None, default=None,
...@@ -333,7 +334,8 @@ class TrainingArguments: ...@@ -333,7 +334,8 @@ class TrainingArguments:
) )
else: else:
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy) self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
if self.do_eval is None:
self.do_eval = self.evaluation_strategy != EvaluationStrategy.NO
if self.eval_steps is None: if self.eval_steps is None:
self.eval_steps = self.logging_steps self.eval_steps = self.logging_steps
...@@ -341,6 +343,8 @@ class TrainingArguments: ...@@ -341,6 +343,8 @@ class TrainingArguments:
self.metric_for_best_model = "loss" self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None: if self.greater_is_better is None and self.metric_for_best_model is not None:
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
if self.run_name is None:
self.run_name = self.output_dir
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment