"examples/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "969f0d99d333f07dc1f7086214762224c7d5cb6a"
Unverified Commit 3f1714f8 authored by Adam Pocock's avatar Adam Pocock Committed by GitHub
Browse files

Adding required flags to non-default arguments in hf_argparser (#10688)



* Adding required flags to non-default arguments.
Signed-off-by: default avatarAdam Pocock <adam.pocock@oracle.com>

* make style fix.

* Update src/transformers/hf_argparser.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6f840990
...@@ -99,6 +99,8 @@ class HfArgumentParser(ArgumentParser): ...@@ -99,6 +99,8 @@ class HfArgumentParser(ArgumentParser):
kwargs["type"] = type(kwargs["choices"][0]) kwargs["type"] = type(kwargs["choices"][0])
if field.default is not dataclasses.MISSING: if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default kwargs["default"] = field.default
else:
kwargs["required"] = True
elif field.type is bool or field.type == Optional[bool]: elif field.type is bool or field.type == Optional[bool]:
if field.default is True: if field.default is True:
self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs) self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
...@@ -124,6 +126,8 @@ class HfArgumentParser(ArgumentParser): ...@@ -124,6 +126,8 @@ class HfArgumentParser(ArgumentParser):
), "{} cannot be a List of mixed types".format(field.name) ), "{} cannot be a List of mixed types".format(field.name)
if field.default_factory is not dataclasses.MISSING: if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory() kwargs["default"] = field.default_factory()
elif field.default is dataclasses.MISSING:
kwargs["required"] = True
else: else:
kwargs["type"] = field.type kwargs["type"] = field.type
if field.default is not dataclasses.MISSING: if field.default is not dataclasses.MISSING:
......
...@@ -78,6 +78,16 @@ class ListExample: ...@@ -78,6 +78,16 @@ class ListExample:
foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3]) foo_float: List[float] = list_field(default=[0.1, 0.2, 0.3])
@dataclass
class RequiredExample:
required_list: List[int] = field()
required_str: str = field()
required_enum: BasicEnum = field()
def __post_init__(self):
self.required_enum = BasicEnum(self.required_enum)
class HfArgumentParserTest(unittest.TestCase): class HfArgumentParserTest(unittest.TestCase):
def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser) -> bool: def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser) -> bool:
""" """
...@@ -186,6 +196,15 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -186,6 +196,15 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split()) args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3])) self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
def test_with_required(self):
parser = HfArgumentParser(RequiredExample)
expected = argparse.ArgumentParser()
expected.add_argument("--required_list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", type=str, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
self.argparsersEqual(parser, expected)
def test_parse_dict(self): def test_parse_dict(self):
parser = HfArgumentParser(BasicExample) parser = HfArgumentParser(BasicExample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment