Unverified Commit 7c6d6329 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[traner] fix --lr_scheduler_type choices (#9800)



* fix --lr_scheduler_type choices

* rewrite to fix for all enum-based cl args

* cleanup

* adjust test

* style

* Proposal that should work

* Remove needless code

* Fix test
Co-authored-by: default avatarSylvain Gugger <sylvain.gugger@gmail.com>
parent 893120fa
...@@ -94,8 +94,8 @@ class HfArgumentParser(ArgumentParser): ...@@ -94,8 +94,8 @@ class HfArgumentParser(ArgumentParser):
field.type = prim_type field.type = prim_type
if isinstance(field.type, type) and issubclass(field.type, Enum): if isinstance(field.type, type) and issubclass(field.type, Enum):
kwargs["choices"] = list(field.type) kwargs["choices"] = [x.value for x in field.type]
kwargs["type"] = field.type kwargs["type"] = type(kwargs["choices"][0])
if field.default is not dataclasses.MISSING: if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]: elif field.type is bool or field.type is Optional[bool]:
...@@ -198,7 +198,7 @@ class HfArgumentParser(ArgumentParser): ...@@ -198,7 +198,7 @@ class HfArgumentParser(ArgumentParser):
data = json.loads(Path(json_file).read_text()) data = json.loads(Path(json_file).read_text())
outputs = [] outputs = []
for dtype in self.dataclass_types: for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype)} keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in data.items() if k in keys} inputs = {k: v for k, v in data.items() if k in keys}
obj = dtype(**inputs) obj = dtype(**inputs)
outputs.append(obj) outputs.append(obj)
...@@ -211,7 +211,7 @@ class HfArgumentParser(ArgumentParser): ...@@ -211,7 +211,7 @@ class HfArgumentParser(ArgumentParser):
""" """
outputs = [] outputs = []
for dtype in self.dataclass_types: for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype)} keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in args.items() if k in keys} inputs = {k: v for k, v in args.items() if k in keys}
obj = dtype(**inputs) obj = dtype(**inputs)
outputs.append(obj) outputs.append(obj)
......
...@@ -55,7 +55,10 @@ class BasicEnum(Enum): ...@@ -55,7 +55,10 @@ class BasicEnum(Enum):
@dataclass @dataclass
class EnumExample: class EnumExample:
foo: BasicEnum = BasicEnum.toto foo: BasicEnum = "toto"
def __post_init__(self):
self.foo = BasicEnum(self.foo)
@dataclass @dataclass
...@@ -133,14 +136,18 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -133,14 +136,18 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(EnumExample) parser = HfArgumentParser(EnumExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum) expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
args = parser.parse_args([]) args = parser.parse_args([])
self.assertEqual(args.foo, BasicEnum.toto) self.assertEqual(args.foo, "toto")
enum_ex = parser.parse_args_into_dataclasses([])[0]
self.assertEqual(enum_ex.foo, BasicEnum.toto)
args = parser.parse_args(["--foo", "titi"]) args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, BasicEnum.titi) self.assertEqual(args.foo, "titi")
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
self.assertEqual(enum_ex.foo, BasicEnum.titi)
def test_with_list(self): def test_with_list(self):
parser = HfArgumentParser(ListExample) parser = HfArgumentParser(ListExample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment