Unverified Commit 487505ff authored by Zach Mueller, committed by GitHub

Allow for str versions of dicts based on typing (#30227)

* Bookmark, initial implementation. Need to test

* Clean

* Working fully, woop woop

* I think working version now, testing

* Fin!

* rm cast, could keep None

* Fix typing issue

* rm typehint

* Add test

* Add tests and make more rigid
parent b86d0f4e
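In practice, this change lets any field listed in `_VALID_DICT_FIELDS` be passed on the command line as a JSON string instead of requiring an actual `dict`. A minimal sketch of the intended usage (requires torch installed; the flag values here are illustrative, not from the diff):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=[
        "--output_dir", "out",
        # A `str` repr of a dict: `__post_init__` json-decodes it, and
        # `_convert_str_dict` coerces "2" -> 2
        "--lr_scheduler_kwargs", '{"num_cycles": "2"}',
    ]
)
assert args.lr_scheduler_kwargs == {"num_cycles": 2}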
@@ -173,6 +173,37 @@ class OptimizerNames(ExplicitEnum):
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+# Sometimes users will pass in a `str` repr of a dict in the CLI
+# We need to track what fields those can be. Each time a new arg
+# has a dict type, it must be added to this list.
+# Important: These should be typed with Optional[Union[dict,str,...]]
+_VALID_DICT_FIELDS = [
+    "accelerator_config",
+    "fsdp_config",
+    "deepspeed",
+    "gradient_checkpointing_kwargs",
+    "lr_scheduler_kwargs",
+]
+
+
+def _convert_str_dict(passed_value: dict):
+    "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
+    for key, value in passed_value.items():
+        if isinstance(value, dict):
+            passed_value[key] = _convert_str_dict(value)
+        elif isinstance(value, str):
+            # First check for bool and convert
+            if value.lower() in ("true", "false"):
+                passed_value[key] = value.lower() == "true"
+            # Check for digit
+            elif value.isdigit():
+                passed_value[key] = int(value)
+            elif value.replace(".", "", 1).isdigit():
+                passed_value[key] = float(value)
+
+    return passed_value
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
@dataclass
class TrainingArguments:
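For reference, a quick sketch of what `_convert_str_dict` does to a decoded payload (the values here are illustrative, not from the diff): nested dicts recurse, and only bool-, int-, or float-looking strings are coerced.

payload = {"split_batches": "True", "num_steps": "2", "lr": "0.5",
           "name": "cosine", "nested": {"flag": "false"}}
print(_convert_str_dict(payload))
# {'split_batches': True, 'num_steps': 2, 'lr': 0.5, 'name': 'cosine', 'nested': {'flag': False}}

Note that negative numbers ("-3") and scientific notation fall through unchanged, since the checks rely on `str.isdigit()`.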
@@ -803,11 +834,11 @@ class TrainingArguments:
        default="linear",
        metadata={"help": "The scheduler type to use."},
    )
-    lr_scheduler_kwargs: Optional[Dict] = field(
+    lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
        default_factory=dict,
        metadata={
            "help": (
-                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
+                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
            )
        },
    )
@@ -1118,7 +1149,6 @@ class TrainingArguments:
            )
        },
    )
-    # Do not touch this type annotation or it will stop working in CLI
    fsdp_config: Optional[Union[dict, str]] = field(
        default=None,
        metadata={
@@ -1137,8 +1167,7 @@ class TrainingArguments:
            )
        },
    )
-    # Do not touch this type annotation or it will stop working in CLI
-    accelerator_config: Optional[str] = field(
+    accelerator_config: Optional[Union[dict, str]] = field(
        default=None,
        metadata={
            "help": (
@@ -1147,8 +1176,7 @@ class TrainingArguments:
            )
        },
    )
-    # Do not touch this type annotation or it will stop working in CLI
-    deepspeed: Optional[str] = field(
+    deepspeed: Optional[Union[dict, str]] = field(
        default=None,
        metadata={
            "help": (
@@ -1252,7 +1280,7 @@ class TrainingArguments:
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )
-    gradient_checkpointing_kwargs: Optional[dict] = field(
+    gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
        default=None,
        metadata={
            "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
@@ -1380,6 +1408,17 @@ class TrainingArguments:
    )

    def __post_init__(self):
+        # Parse in args that could be `dict` sent in from the CLI as a string
+        for field in _VALID_DICT_FIELDS:
+            passed_value = getattr(self, field)
+            # We only want to do this if the str starts with a bracket to indicate a `dict`;
+            # otherwise it's likely a filename, if supported
+            if isinstance(passed_value, str) and passed_value.startswith("{"):
+                loaded_dict = json.loads(passed_value)
+                # Convert str values to types if applicable
+                loaded_dict = _convert_str_dict(loaded_dict)
+                setattr(self, field, loaded_dict)
+
        # expand paths, if not os.makedirs("~/bar") will make directory
        # in the current directory instead of the actual home
        # see https://github.com/huggingface/transformers/issues/10628
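The `startswith("{")` guard is what keeps path-style values working: only strings that look like a JSON object get decoded, so something like a config filename passes through untouched. A standalone sketch (the filename is hypothetical):

import json

for passed_value in ('{"use_reentrant": "false"}', "my_config.json"):
    if isinstance(passed_value, str) and passed_value.startswith("{"):
        # Brace-prefixed: decode and coerce -> {'use_reentrant': False}
        print(_convert_str_dict(json.loads(passed_value)))
    else:
        # Anything else (e.g. a file path) is left as-is for later loading
        print(passed_value)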
@@ -22,12 +22,14 @@ from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
-from typing import List, Literal, Optional
+from typing import Dict, List, Literal, Optional, Union, get_args, get_origin

import yaml

from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import make_choice_type_function, string_to_bool
+from transformers.testing_utils import require_torch
+from transformers.training_args import _VALID_DICT_FIELDS

# Since Python 3.10, we can use the builtin `|` operator for Union types
@@ -405,3 +407,68 @@ class HfArgumentParserTest(unittest.TestCase):
    def test_integration_training_args(self):
        parser = HfArgumentParser(TrainingArguments)
        self.assertIsNotNone(parser)
+
+    def test_valid_dict_annotation(self):
+        """
+        Tests to make sure that `dict`-based annotations
+        are correctly made in `TrainingArguments`.
+
+        If this fails, a type annotation change is
+        needed on a new input.
+        """
+        base_list = _VALID_DICT_FIELDS.copy()
+        args = TrainingArguments
+
+        # First find any annotations that contain `dict`
+        fields = args.__dataclass_fields__
+
+        raw_dict_fields = []
+        optional_dict_fields = []
+
+        for field in fields.values():
+            # First verify raw dict
+            if field.type in (dict, Dict):
+                raw_dict_fields.append(field)
+            # Next check for `Union` or `Optional`
+            elif get_origin(field.type) == Union:
+                if any(arg in (dict, Dict) for arg in get_args(field.type)):
+                    optional_dict_fields.append(field)
+
+        # First check: anything in `raw_dict_fields` is very bad
+        self.assertEqual(
+            len(raw_dict_fields),
+            0,
+            "Found invalid raw `dict` types in the `TrainingArguments` typings. "
+            "This leads to issues with the CLI. Please turn this into `typing.Optional[typing.Union[dict, str]]`",
+        )
+
+        # Next check raw annotations
+        for field in optional_dict_fields:
+            args = get_args(field.type)
+            # These should be returned as `dict`, `str`, ...
+            # we only care about the first two
+            self.assertIn(args[0], (Dict, dict))
+            self.assertEqual(
+                str(args[1]),
+                "<class 'str'>",
+                f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict, str, ...]` for CLI compatibility, "
+                "but `str` not found. Please fix this.",
+            )
+
+        # Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
+        for field in optional_dict_fields:
+            self.assertIn(
+                field.name,
+                base_list,
+                f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `training_args._VALID_DICT_FIELDS`",
+            )
+
+    @require_torch
+    def test_valid_dict_input_parsing(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = TrainingArguments(
+                output_dir=tmp_dir,
+                accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
+            )
+            self.assertEqual(args.accelerator_config.split_batches, True)
+            self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
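The introspection in `test_valid_dict_annotation` leans on `typing` flattening `Optional[Union[dict, str]]` into a single `Union`; a quick sketch of what `get_origin`/`get_args` report for the annotation the test enforces:

from typing import Optional, Union, get_args, get_origin

annotation = Optional[Union[dict, str]]  # normalizes to Union[dict, str, None]
print(get_origin(annotation))  # typing.Union
print(get_args(annotation))    # (<class 'dict'>, <class 'str'>, <class 'NoneType'>)
# Hence args[0] is `dict` and args[1] is `str`, matching the assertions above.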