"tests/tokenization/test_tokenization_fast.py" did not exist on "b4b562d83415d842c081f571e0ec325f40f276aa"
Unverified Commit 893120fa authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Allow --arg Value for booleans in HfArgumentParser (#9823)

* Allow --arg Value for booleans in HfArgumentParser

* Update last test

* Better error message
parent 35d55b7b
......@@ -15,7 +15,7 @@
import dataclasses
import json
import sys
from argparse import ArgumentParser
from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
......@@ -25,6 +25,20 @@ DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise ArgumentTypeError(
f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
)
class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
......@@ -85,11 +99,20 @@ class HfArgumentParser(ArgumentParser):
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]:
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True:
field_name = f"--no_{field.name}"
kwargs["dest"] = field.name
self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
# Hack because type=bool in argparse does not behave as we want.
kwargs["type"] = string_to_bool
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
# Default value is True if we have no default when of type bool.
default = True if field.default is dataclasses.MISSING else field.default
# This is the value that will get picked if we don't include --field_name in any way
kwargs["default"] = default
# This tells argparse we accept 0 or 1 value after --field_name
kwargs["nargs"] = "?"
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0]
......
......@@ -20,6 +20,7 @@ from enum import Enum
from typing import List, Optional
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
def list_field(default=None, metadata=None):
......@@ -44,6 +45,7 @@ class WithDefaultExample:
class WithDefaultBoolExample:
foo: bool = False
baz: bool = True
opt: Optional[bool] = None
class BasicEnum(Enum):
......@@ -91,7 +93,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--bar", type=float, required=True)
expected.add_argument("--baz", type=str, required=True)
expected.add_argument("--flag", action="store_true")
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
self.argparsersEqual(parser, expected)
def test_with_default(self):
......@@ -106,15 +108,26 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(WithDefaultBoolExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", action="store_true")
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--no_baz", action="store_false", dest="baz")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
expected.add_argument("--opt", type=string_to_bool, default=None)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True))
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False))
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self):
parser = HfArgumentParser(EnumExample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment