Unverified Commit c9486fd0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix default bool in argparser (#12424)

* Fix default bool in argparser

* Add more to test
parent 90d69456
...@@ -112,8 +112,8 @@ class HfArgumentParser(ArgumentParser): ...@@ -112,8 +112,8 @@ class HfArgumentParser(ArgumentParser):
# Hack because type=bool in argparse does not behave as we want. # Hack because type=bool in argparse does not behave as we want.
kwargs["type"] = string_to_bool kwargs["type"] = string_to_bool
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
# Default value is True if we have no default when of type bool. # Default value is False if we have no default when of type bool.
default = True if field.default is dataclasses.MISSING else field.default default = False if field.default is dataclasses.MISSING else field.default
# This is the value that will get picked if we don't include --field_name in any way # This is the value that will get picked if we don't include --field_name in any way
kwargs["default"] = default kwargs["default"] = default
# This tells argparse we accept 0 or 1 value after --field_name # This tells argparse we accept 0 or 1 value after --field_name
......
...@@ -106,9 +106,13 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -106,9 +106,13 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", type=int, required=True) expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--bar", type=float, required=True) expected.add_argument("--bar", type=float, required=True)
expected.add_argument("--baz", type=str, required=True) expected.add_argument("--baz", type=str, required=True)
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?") expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?")
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"]
(example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
self.assertFalse(example.flag)
def test_with_default(self): def test_with_default(self):
parser = HfArgumentParser(WithDefaultExample) parser = HfArgumentParser(WithDefaultExample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment