Unverified Commit bb9559a7 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Don't use `store_xxx` on optional bools (#7786)

* Don't use `store_xxx` on optional bools

* Refine test

* Refine test
parent a1d1b332
...@@ -59,7 +59,7 @@ class TorchXLAExamplesTests(unittest.TestCase): ...@@ -59,7 +59,7 @@ class TorchXLAExamplesTests(unittest.TestCase):
--model_name_or_path=bert-base-cased --model_name_or_path=bert-base-cased
--per_device_train_batch_size=64 --per_device_train_batch_size=64
--per_device_eval_batch_size=64 --per_device_eval_batch_size=64
--evaluate_during_training --evaluation_strategy steps
--overwrite_cache --overwrite_cache
""".split() """.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
......
...@@ -43,7 +43,7 @@ python run_tf_text_classification.py \ ...@@ -43,7 +43,7 @@ python run_tf_text_classification.py \
--do_eval \ --do_eval \
--do_predict \ --do_predict \
--logging_steps 10 \ --logging_steps 10 \
--evaluate_during_training \ --evaluation_strategy steps \
--save_steps 10 \ --save_steps 10 \
--overwrite_output_dir \ --overwrite_output_dir \
--max_seq_length 128 --max_seq_length 128
......
...@@ -65,7 +65,8 @@ class HfArgumentParser(ArgumentParser): ...@@ -65,7 +65,8 @@ class HfArgumentParser(ArgumentParser):
if field.default is not dataclasses.MISSING: if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]: elif field.type is bool or field.type is Optional[bool]:
kwargs["action"] = "store_false" if field.default is True else "store_true" if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True: if field.default is True:
field_name = f"--no-{field.name}" field_name = f"--no-{field.name}"
kwargs["dest"] = field.name kwargs["dest"] = field.name
......
...@@ -191,7 +191,7 @@ class TrainingArguments: ...@@ -191,7 +191,7 @@ class TrainingArguments:
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."}) do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluate_during_training: bool = field( evaluate_during_training: bool = field(
default=None, default=False,
metadata={"help": "Run evaluation during training at each logging step."}, metadata={"help": "Run evaluation during training at each logging step."},
) )
evaluation_strategy: EvaluationStrategy = field( evaluation_strategy: EvaluationStrategy = field(
......
...@@ -85,7 +85,7 @@ ...@@ -85,7 +85,7 @@
pass-as: --output_dir={v} pass-as: --output_dir={v}
type: string type: string
default: /valohai/outputs default: /valohai/outputs
- name: evaluate_during_training - name: evaluation_strategy
description: Run evaluation during training at each logging step. description: The evaluation strategy to use.
type: flag type: string
default: true default: steps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment