"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1690094bdb8f606685e46f26c17ddadb925d0a20"
Unverified Commit 86387fe8 authored by Felix Schneider's avatar Felix Schneider Committed by GitHub
Browse files

Add an option to `HfArgumentParser.parse_{dict,json_file}` to raise an...

Add an option to `HfArgumentParser.parse_{dict,json_file}` to raise an Exception when there extra keys (#18692)

* Update parser to track unneeded keys, off by default

* Fix formatting

* Fix docstrings and defaults in HfArgparser

* Fix formatting
parent f210e2a4
......@@ -234,29 +234,60 @@ class HfArgumentParser(ArgumentParser):
return (*outputs,)
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
dataclass types.
Args:
json_file (`str` or `os.PathLike`):
File name of the json file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
data = json.loads(Path(json_file).read_text())
unused_keys = set(data.keys())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in data.items() if k in keys}
unused_keys.difference_update(inputs.keys())
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
types.
Args:
args (`dict`):
dict containing config values
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
unused_keys = set(args.keys())
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in args.items() if k in keys}
unused_keys.difference_update(inputs.keys())
obj = dtype(**inputs)
outputs.append(obj)
return (*outputs,)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)
......@@ -245,6 +245,19 @@ class HfArgumentParserTest(unittest.TestCase):
args = BasicExample(**args_dict)
self.assertEqual(parsed_args, args)
def test_parse_dict_extra_key(self):
parser = HfArgumentParser(BasicExample)
args_dict = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
"extra": 42,
}
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
def test_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment