Unverified Commit 838dc06f authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

parse arguments from dict (#4869)

* add parse_dict to parse arguments from dict

* add unit test for parse_dict
parent cf3cf304
......@@ -158,3 +158,16 @@ class HfArgumentParser(ArgumentParser):
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all,
instead uses a dict and populating the dataclass types.
"""
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype)}
inputs = {k: v for k, v in args.items() if k in keys}
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
......@@ -152,6 +152,20 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
def test_parse_dict(self):
parser = HfArgumentParser(BasicExample)
args_dict = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
parsed_args = parser.parse_dict(args_dict)[0]
args = BasicExample(**args_dict)
self.assertEqual(parsed_args, args)
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment