Unverified Commit 1e3f17b5 authored by Konstantin Dobler's avatar Konstantin Dobler Committed by GitHub
Browse files

Enhance HfArgumentParser functionality and ease of use (#20323)

* Enhance HfArgumentParser

* Fix type hints for older python versions

* Fix and add tests (+formatting)

* Add changes

* doc-builder formatting

* Remove unused import "Call"
parent 96783e53
...@@ -20,11 +20,19 @@ from copy import copy ...@@ -20,11 +20,19 @@ from copy import copy
from enum import Enum from enum import Enum
from inspect import isclass from inspect import isclass
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints from typing import Any, Callable, Dict, Iterable, List, NewType, Optional, Tuple, Union, get_type_hints
import yaml import yaml
try:
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
from typing import Literal
except ImportError:
# For Python 3.7
from typing_extensions import Literal
DataClass = NewType("DataClass", Any) DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any) DataClassType = NewType("DataClassType", Any)
...@@ -43,6 +51,68 @@ def string_to_bool(v): ...@@ -43,6 +51,68 @@ def string_to_bool(v):
) )
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
    """
    Build a converter that turns the string form of a choice back into the choice itself. This lets a single
    argument accept choices of several value types (e.g. both `"toto"` and `42`).

    Args:
        choices (list): The allowed choice values.

    Returns:
        Callable[[str], Any]: A function that returns the choice whose `str()` equals the given argument, or the
        argument unchanged when no choice matches.
    """
    # Key every choice by its string representation; if two choices share one, the later choice wins.
    lookup = {}
    for option in choices:
        lookup[str(option)] = option

    def convert(raw):
        # Unknown strings are passed through untouched.
        return lookup.get(raw, raw)

    return convert
def HfArg(
    *,
    aliases: Optional[Union[str, List[str]]] = None,
    help: Optional[str] = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
    metadata: Optional[dict] = None,
    **kwargs,
) -> dataclasses.Field:
    """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.

    Example comparing the use of `HfArg` and `dataclasses.field`:
    ```
    @dataclass
    class Args:
        regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
        hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
    ```

    Args:
        aliases (Union[str, List[str]], optional):
            Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
            Defaults to None.
        help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
        default (Any, optional):
            Default value for the argument. If neither default nor default_factory is specified, the argument is
            required. Defaults to dataclasses.MISSING.
        default_factory (Callable[[], Any], optional):
            The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
            default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
            Defaults to dataclasses.MISSING.
        metadata (dict, optional):
            Further metadata to pass on to `dataclasses.field`. The mapping is copied, so the caller's object is never
            modified and may safely be reused for several fields. Defaults to None.
        **kwargs: Any further keyword arguments are forwarded unchanged to `dataclasses.field`.

    Returns:
        Field: A `dataclasses.Field` with the desired properties.
    """
    # Always work on a private dict. A mutable default in the signature would be shared across calls, and writing
    # into a caller-supplied dict would leak `aliases`/`help` into every other field built from the same dict
    # (the created field keeps a live view of the dict it was given).
    metadata = {} if metadata is None else dict(metadata)
    if aliases is not None:
        metadata["aliases"] = aliases
    if help is not None:
        metadata["help"] = help

    return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
class HfArgumentParser(ArgumentParser): class HfArgumentParser(ArgumentParser):
""" """
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
...@@ -84,6 +154,10 @@ class HfArgumentParser(ArgumentParser): ...@@ -84,6 +154,10 @@ class HfArgumentParser(ArgumentParser):
"`typing.get_type_hints` method by default" "`typing.get_type_hints` method by default"
) )
aliases = kwargs.pop("aliases", [])
if isinstance(aliases, str):
aliases = [aliases]
origin_type = getattr(field.type, "__origin__", field.type) origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is Union: if origin_type is Union:
if str not in field.type.__args__ and ( if str not in field.type.__args__ and (
...@@ -108,9 +182,14 @@ class HfArgumentParser(ArgumentParser): ...@@ -108,9 +182,14 @@ class HfArgumentParser(ArgumentParser):
# A variable to store kwargs for a boolean field, if needed # A variable to store kwargs for a boolean field, if needed
# so that we can init a `no_*` complement argument (see below) # so that we can init a `no_*` complement argument (see below)
bool_kwargs = {} bool_kwargs = {}
if isinstance(field.type, type) and issubclass(field.type, Enum): if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
if origin_type is Literal:
kwargs["choices"] = field.type.__args__
else:
kwargs["choices"] = [x.value for x in field.type] kwargs["choices"] = [x.value for x in field.type]
kwargs["type"] = type(kwargs["choices"][0])
kwargs["type"] = make_choice_type_function(kwargs["choices"])
if field.default is not dataclasses.MISSING: if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default kwargs["default"] = field.default
else: else:
...@@ -146,7 +225,7 @@ class HfArgumentParser(ArgumentParser): ...@@ -146,7 +225,7 @@ class HfArgumentParser(ArgumentParser):
kwargs["default"] = field.default_factory() kwargs["default"] = field.default_factory()
else: else:
kwargs["required"] = True kwargs["required"] = True
parser.add_argument(field_name, **kwargs) parser.add_argument(field_name, *aliases, **kwargs)
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added. # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
# Order is important for arguments with the same destination! # Order is important for arguments with the same destination!
...@@ -178,7 +257,12 @@ class HfArgumentParser(ArgumentParser): ...@@ -178,7 +257,12 @@ class HfArgumentParser(ArgumentParser):
self._parse_dataclass_field(parser, field) self._parse_dataclass_field(parser, field)
def parse_args_into_dataclasses( def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None self,
args=None,
return_remaining_strings=False,
look_for_args_file=True,
args_filename=None,
args_file_flag=None,
) -> Tuple[DataClass, ...]: ) -> Tuple[DataClass, ...]:
""" """
Parse command-line args into instances of the specified dataclass types. Parse command-line args into instances of the specified dataclass types.
...@@ -196,6 +280,9 @@ class HfArgumentParser(ArgumentParser): ...@@ -196,6 +280,9 @@ class HfArgumentParser(ArgumentParser):
process, and will append its potential content to the command line args. process, and will append its potential content to the command line args.
args_filename: args_filename:
If not None, will uses this file instead of the ".args" file specified in the previous argument. If not None, will uses this file instead of the ".args" file specified in the previous argument.
args_file_flag:
If not None, will look for a file in the command-line args specified with this flag. The flag can be
specified multiple times and precedence is determined by the order (last one wins).
Returns: Returns:
Tuple consisting of: Tuple consisting of:
...@@ -205,17 +292,36 @@ class HfArgumentParser(ArgumentParser): ...@@ -205,17 +292,36 @@ class HfArgumentParser(ArgumentParser):
after initialization. after initialization.
- The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args) - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
""" """
if args_filename or (look_for_args_file and len(sys.argv)):
if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
args_files = []
if args_filename: if args_filename:
args_file = Path(args_filename) args_files.append(Path(args_filename))
else: elif look_for_args_file and len(sys.argv):
args_file = Path(sys.argv[0]).with_suffix(".args") args_files.append(Path(sys.argv[0]).with_suffix(".args"))
# args files specified via command line flag should overwrite default args files so we add them last
if args_file_flag:
# Create special parser just to extract the args_file_flag values
args_file_parser = ArgumentParser()
args_file_parser.add_argument(args_file_flag, type=str, action="append")
# Use only remaining args for further parsing (remove the args_file_flag)
cfg, args = args_file_parser.parse_known_args(args=args)
cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)
if cmd_args_file_paths:
args_files.extend([Path(p) for p in cmd_args_file_paths])
file_args = []
for args_file in args_files:
if args_file.exists(): if args_file.exists():
fargs = args_file.read_text().split() file_args += args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
# in case of duplicate arguments the first one has precedence # in case of duplicate arguments the last one has precedence
# so we append rather than prepend. # args specified via the command line should overwrite args from files, so we add them last
args = file_args + args if args is not None else file_args + sys.argv[1:]
namespace, remaining_args = self.parse_known_args(args=args) namespace, remaining_args = self.parse_known_args(args=args)
outputs = [] outputs = []
for dtype in self.dataclass_types: for dtype in self.dataclass_types:
......
...@@ -25,7 +25,15 @@ from typing import List, Optional ...@@ -25,7 +25,15 @@ from typing import List, Optional
import yaml import yaml
from transformers import HfArgumentParser, TrainingArguments from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool from transformers.hf_argparser import make_choice_type_function, string_to_bool
try:
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
from typing import Literal
except ImportError:
# For Python 3.7
from typing_extensions import Literal
def list_field(default=None, metadata=None): def list_field(default=None, metadata=None):
...@@ -58,6 +66,12 @@ class BasicEnum(Enum): ...@@ -58,6 +66,12 @@ class BasicEnum(Enum):
toto = "toto" toto = "toto"
class MixedTypeEnum(Enum):
    """Enum fixture whose member values are of different types: two strings and one integer."""

    titi = "titi"
    toto = "toto"
    # Integer-valued member: on the command line it arrives as the string "42" and has to be mapped back to 42.
    fourtytwo = 42
@dataclass @dataclass
class EnumExample: class EnumExample:
foo: BasicEnum = "toto" foo: BasicEnum = "toto"
...@@ -66,6 +80,14 @@ class EnumExample: ...@@ -66,6 +80,14 @@ class EnumExample:
self.foo = BasicEnum(self.foo) self.foo = BasicEnum(self.foo)
@dataclass
class MixedTypeEnumExample:
    """Dataclass fixture with one `MixedTypeEnum` field, `foo`, whose default is given as the raw value "toto"."""

    foo: MixedTypeEnum = "toto"

    def __post_init__(self):
        # Turn the raw value stored in `foo` into the corresponding `MixedTypeEnum` member.
        self.foo = MixedTypeEnum(self.foo)
@dataclass @dataclass
class OptionalExample: class OptionalExample:
foo: Optional[int] = None foo: Optional[int] = None
...@@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase):
for x, y in zip(a._actions, b._actions): for x, y in zip(a._actions, b._actions):
xx = {k: v for k, v in vars(x).items() if k != "container"} xx = {k: v for k, v in vars(x).items() if k != "container"}
yy = {k: v for k, v in vars(y).items() if k != "container"} yy = {k: v for k, v in vars(y).items() if k != "container"}
# Choices with mixed type have custom function as "type"
# So we need to compare results directly for equality
if xx.get("choices", None) and yy.get("choices", None):
for expected_choice in yy["choices"] + xx["choices"]:
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
del xx["type"], yy["type"]
self.assertEqual(xx, yy) self.assertEqual(xx, yy)
def test_basic(self): def test_basic(self):
...@@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False)) self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self): def test_with_enum(self):
parser = HfArgumentParser(EnumExample) parser = HfArgumentParser(MixedTypeEnumExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str) expected.add_argument(
"--foo",
default="toto",
choices=["titi", "toto", 42],
type=make_choice_type_function(["titi", "toto", 42]),
)
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
args = parser.parse_args([]) args = parser.parse_args([])
self.assertEqual(args.foo, "toto") self.assertEqual(args.foo, "toto")
enum_ex = parser.parse_args_into_dataclasses([])[0] enum_ex = parser.parse_args_into_dataclasses([])[0]
self.assertEqual(enum_ex.foo, BasicEnum.toto) self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
args = parser.parse_args(["--foo", "titi"]) args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, "titi") self.assertEqual(args.foo, "titi")
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0] enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
self.assertEqual(enum_ex.foo, BasicEnum.titi) self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42)
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
def test_with_literal(self):
    """A `Literal` field with mixed-type values yields a `choices` argument that parses each value to its real type."""

    @dataclass
    class LiteralExample:
        foo: Literal["titi", "toto", 42] = "toto"

    literal_parser = HfArgumentParser(LiteralExample)

    # Hand-built argparse parser the generated one must be equivalent to.
    reference = argparse.ArgumentParser()
    reference.add_argument(
        "--foo",
        default="toto",
        choices=("titi", "toto", 42),
        type=make_choice_type_function(["titi", "toto", 42]),
    )
    self.argparsersEqual(literal_parser, reference)

    # (command-line arguments, expected parsed value of `foo`)
    cases = [
        ([], "toto"),
        (["--foo", "titi"], "titi"),
        (["--foo", "42"], 42),
    ]
    for argv, wanted in cases:
        parsed = literal_parser.parse_args(argv)
        self.assertEqual(parsed.foo, wanted)
def test_with_list(self): def test_with_list(self):
parser = HfArgumentParser(ListExample) parser = HfArgumentParser(ListExample)
...@@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase):
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
expected.add_argument("--required_list", nargs="+", type=int, required=True) expected.add_argument("--required_list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", type=str, required=True) expected.add_argument("--required_str", type=str, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True) expected.add_argument(
"--required_enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
def test_with_string_literal_annotation(self): def test_with_string_literal_annotation(self):
...@@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase):
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=int, required=True) expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True) expected.add_argument(
"--required_enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
expected.add_argument("--opt", type=string_to_bool, default=None) expected.add_argument("--opt", type=string_to_bool, default=None)
expected.add_argument("--baz", default="toto", type=str, help="help message") expected.add_argument("--baz", default="toto", type=str, help="help message")
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str) expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment