Unverified Commit be4f2699 authored by IMvision12's avatar IMvision12 Committed by GitHub
Browse files

Updated hf_argparser.py (#19188)

* Changed json_file_parser function and added yaml parser function

* update hf_argparser

* Added allow_extra_keys argument
parent c20b2c7e
...@@ -22,6 +22,8 @@ from inspect import isclass ...@@ -22,6 +22,8 @@ from inspect import isclass
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
import yaml
DataClass = NewType("DataClass", Any) DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any) DataClassType = NewType("DataClassType", Any)
...@@ -234,29 +236,27 @@ class HfArgumentParser(ArgumentParser): ...@@ -234,29 +236,27 @@ class HfArgumentParser(ArgumentParser):
return (*outputs,) return (*outputs,)
def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
""" """
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
dataclass types. types.
Args: Args:
json_file (`str` or `os.PathLike`): args (`dict`):
File name of the json file to parse dict containing config values
allow_extra_keys (`bool`, *optional*, defaults to `False`): allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
parsed.
Returns: Returns:
Tuple consisting of: Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer. - the dataclass instances in the same order as they were passed to the initializer.
""" """
data = json.loads(Path(json_file).read_text()) unused_keys = set(args.keys())
unused_keys = set(data.keys())
outputs = [] outputs = []
for dtype in self.dataclass_types: for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init} keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: v for k, v in data.items() if k in keys} inputs = {k: v for k, v in args.items() if k in keys}
unused_keys.difference_update(inputs.keys()) unused_keys.difference_update(inputs.keys())
obj = dtype(**inputs) obj = dtype(**inputs)
outputs.append(obj) outputs.append(obj)
...@@ -264,30 +264,42 @@ class HfArgumentParser(ArgumentParser): ...@@ -264,30 +264,42 @@ class HfArgumentParser(ArgumentParser):
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs) return tuple(outputs)
def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]: def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
""" """
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
types. dataclass types.
Args: Args:
args (`dict`): json_file (`str` or `os.PathLike`):
dict containing config values File name of the json file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`): allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed. Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns: Returns:
Tuple consisting of: Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer. - the dataclass instances in the same order as they were passed to the initializer.
""" """
unused_keys = set(args.keys()) outputs = self.parse_dict(json.loads(Path(json_file).read_text()), allow_extra_keys=allow_extra_keys)
outputs = [] return tuple(outputs)
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init} def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
inputs = {k: v for k, v in args.items() if k in keys} """
unused_keys.difference_update(inputs.keys()) Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
obj = dtype(**inputs) dataclass types.
outputs.append(obj)
if not allow_extra_keys and unused_keys: Args:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") yaml_file (`str` or `os.PathLike`):
File name of the yaml file to parse
allow_extra_keys (`bool`, *optional*, defaults to `False`):
Defaults to False. If False, will raise an exception if the json file contains keys that are not
parsed.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
"""
outputs = self.parse_dict(yaml.safe_load(yaml_file), allow_extra_keys=allow_extra_keys)
return tuple(outputs) return tuple(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment