Unverified commit b169ac9c, authored by Julien Chaumond, committed by GitHub

[examples] Generate argparsers from type hints on dataclasses (#3669)

* [examples] Generate argparsers from type hints on dataclasses

* [HfArgumentParser] way simpler API

* Restore run_language_modeling.py for easier diff

* [HfArgumentParser] final tweaks from code review
parent 7a7fdf71
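
In a nutshell: `HfArgumentParser` reads the type hints and `field()` metadata of a dataclass and generates the corresponding `argparse` arguments. A minimal sketch of the intended usage (the `Args` class and its fields are illustrative, not part of this commit):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class Args:
    # No default, so the generated --model_name_or_path argument is required
    model_name_or_path: str = field(metadata={"help": "Path to a pre-trained model"})
    # A default makes the generated argument optional
    seed: int = field(default=42, metadata={"help": "Random seed"})


parser = HfArgumentParser(Args)
(args,) = parser.parse_args_into_dataclasses(["--model_name_or_path", "bert-base-uncased"])
print(args.model_name_or_path, args.seed)  # bert-base-uncased 42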
examples/run_glue.py

@@ -22,6 +22,8 @@ import json
 import logging
 import os
 import random
+from dataclasses import dataclass, field
+from typing import Optional

 import numpy as np
 import torch
@@ -36,6 +38,8 @@ from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification,
     AutoTokenizer,
+    HfArgumentParser,
+    TrainingArguments,
     get_linear_schedule_with_warmup,
 )
 from transformers import glue_compute_metrics as compute_metrics
@@ -376,137 +380,54 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     return dataset


-def main():
-    parser = argparse.ArgumentParser()
-
-    # Required parameters
-    parser.add_argument(
-        "--data_dir",
-        default=None,
-        type=str,
-        required=True,
-        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
-    )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
-    )
-    parser.add_argument(
-        "--model_name_or_path",
-        default=None,
-        type=str,
-        required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
-    )
-    parser.add_argument(
-        "--task_name",
-        default=None,
-        type=str,
-        required=True,
-        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
-    )
-    parser.add_argument(
-        "--output_dir",
-        default=None,
-        type=str,
-        required=True,
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-
-    # Other parameters
-    parser.add_argument(
-        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
-    )
-    parser.add_argument(
-        "--tokenizer_name",
-        default="",
-        type=str,
-        help="Pretrained tokenizer name or path if not the same as model_name",
-    )
-    parser.add_argument(
-        "--cache_dir",
-        default="",
-        type=str,
-        help="Where do you want to store the pre-trained models downloaded from s3",
-    )
-    parser.add_argument(
-        "--max_seq_length",
-        default=128,
-        type=int,
-        help="The maximum total input sequence length after tokenization. Sequences longer "
-        "than this will be truncated, sequences shorter will be padded.",
-    )
-    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
-    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
-    parser.add_argument(
-        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
-    )
-    parser.add_argument(
-        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
-    )
-    parser.add_argument(
-        "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
-    )
-    parser.add_argument(
-        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
-    )
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
-    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
-    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
-    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
-    parser.add_argument(
-        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
-    )
-    parser.add_argument(
-        "--max_steps",
-        default=-1,
-        type=int,
-        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
-    )
-    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
-    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
-    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
-    parser.add_argument(
-        "--eval_all_checkpoints",
-        action="store_true",
-        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
-    )
-    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
-    parser.add_argument(
-        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
-    )
-    parser.add_argument(
-        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
-    )
-    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
-    parser.add_argument(
-        "--fp16",
-        action="store_true",
-        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
-    )
-    parser.add_argument(
-        "--fp16_opt_level",
-        type=str,
-        default="O1",
-        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-        "See details at https://nvidia.github.io/apex/amp.html",
-    )
-    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
-    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
-    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
-    args = parser.parse_args()
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)}
+    )
+    model_type: str = field(metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)})
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pre-trained models downloaded from s3"}
+    )
+
+
+@dataclass
+class DataProcessingArguments:
+    task_name: str = field(
+        metadata={"help": "The name of the task to train selected in the list: " + ", ".join(processors.keys())}
+    )
+    data_dir: str = field(
+        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
+    )
+    max_seq_length: int = field(
+        default=128,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+
+
+def main():
+    parser = HfArgumentParser((ModelArguments, DataProcessingArguments, TrainingArguments))
+    model_args, dataprocessing_args, training_args = parser.parse_args_into_dataclasses()
+
+    # For now, let's merge all the sets of args into one,
+    # but soon, we'll keep distinct sets of args, with a cleaner separation of concerns.
+    args = argparse.Namespace(**vars(model_args), **vars(dataprocessing_args), **vars(training_args))

     if (
         os.path.exists(args.output_dir)
@@ -515,20 +436,9 @@ def main():
         and not args.overwrite_output_dir
     ):
         raise ValueError(
-            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
-                args.output_dir
-            )
+            f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
         )

-    # Setup distant debugging if needed
-    if args.server_ip and args.server_port:
-        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
-        import ptvsd
-
-        print("Waiting for debugger attach")
-        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
-        ptvsd.wait_for_attach()
-
     # Setup CUDA, GPU & distributed training
     if args.local_rank == -1 or args.no_cuda:
         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -576,18 +486,16 @@ def main():
         args.config_name if args.config_name else args.model_name_or_path,
         num_labels=num_labels,
         finetuning_task=args.task_name,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )
     tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-        do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir,
     )
     model = AutoModelForSequenceClassification.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )

     if args.local_rank == 0:
@@ -629,7 +537,7 @@ def main():
     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
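
The single flat `args` namespace used by the rest of the script is produced by merging the three dataclass instances in the new `main()` above; `vars()` works here because dataclass instances carry a regular `__dict__`. A standalone sketch of the pattern, with toy stand-ins for the real argument classes:

import argparse
from dataclasses import dataclass


@dataclass
class Model:
    model_name_or_path: str = "bert-base-uncased"


@dataclass
class Data:
    max_seq_length: int = 128


# Same pattern as in the new main(): merge per-concern dataclasses into one Namespace
args = argparse.Namespace(**vars(Model()), **vars(Data()))
print(args.model_name_or_path, args.max_seq_length)  # bert-base-uncased 128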
src/transformers/__init__.py
@@ -88,6 +88,7 @@ from .file_utils import (
     is_tf_available,
     is_torch_available,
 )
+from .hf_argparser import HfArgumentParser

 # Model Cards
 from .modelcard import ModelCard
@@ -141,6 +142,7 @@ from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
+from .training_args import TrainingArguments

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
src/transformers/hf_argparser.py (new file)
import dataclasses
from argparse import ArgumentParser
from enum import Enum
from typing import Any, Iterable, NewType, Tuple, Union

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
    to generate arguments.

    The class is designed to play well with the native argparse. In particular,
    you can add more (non-dataclass backed) arguments to the parser after initialization
    and you'll get the output back after parsing as an additional namespace.
    """

    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances
                with the parsed args.
            kwargs:
                (Optional) Passed to `argparse.ArgumentParser()` in the regular way.
        """
        super().__init__(**kwargs)
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        self.dataclass_types = dataclass_types
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)
    def _add_dataclass_arguments(self, dtype: DataClassType):
        for field in dataclasses.fields(dtype):
            field_name = f"--{field.name}"
            kwargs = field.metadata.copy()
            # field.metadata is not used at all by Data Classes,
            # it is provided as a third-party extension mechanism.
            if isinstance(field.type, str):
                raise ImportError(
                    "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563), "
                    "which can be opted in from Python 3.7 with `from __future__ import annotations`. "
                    "We will add compatibility when Python 3.9 is released."
                )
            typestring = str(field.type)
            for x in (int, float, str):
                if typestring == f"typing.Union[{x.__name__}, NoneType]":
                    field.type = x
            if isinstance(field.type, type) and issubclass(field.type, Enum):
                kwargs["choices"] = list(field.type)
                kwargs["type"] = field.type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
            elif field.type is bool:
                kwargs["action"] = "store_false" if field.default is True else "store_true"
                if field.default is True:
                    field_name = f"--no-{field.name}"
                    kwargs["dest"] = field.name
            else:
                kwargs["type"] = field.type
                if field.default is not dataclasses.MISSING:
                    kwargs["default"] = field.default
                else:
                    kwargs["required"] = True
            self.add_argument(field_name, **kwargs)
    def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`.
        See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv.
                (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.

        Returns:
            Tuple consisting of:
                - the dataclass instances in the same order as they
                  were passed to the initializer.
                - if applicable, an additional namespace for more
                  (non-dataclass backed) arguments added to the parser
                  after initialization.
                - the potential list of remaining argument strings.
                  (same as argparse.ArgumentParser.parse_known_args)
        """
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype)}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            for k in keys:
                delattr(namespace, k)
            obj = dtype(**inputs)
            outputs.append(obj)
        if len(namespace.__dict__) > 0:
            # additional namespace.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        else:
            return (*outputs,)
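
As the class docstring notes, plain `argparse` arguments added after construction survive parsing and come back as a trailing namespace. A minimal sketch of that behavior (the `ModelArgs` class and `--verbose` flag are illustrative, not part of this commit):

from dataclasses import dataclass, field

from transformers.hf_argparser import HfArgumentParser


@dataclass
class ModelArgs:
    model_name_or_path: str = field(metadata={"help": "Path to a pre-trained model"})


parser = HfArgumentParser(ModelArgs)
parser.add_argument("--verbose", action="store_true")  # plain, non-dataclass argument

model_args, extra = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "bert-base-uncased", "--verbose"]
)
print(model_args.model_name_or_path)  # bert-base-uncased
print(extra.verbose)  # True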
src/transformers/training_args.py (new file)

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TrainingArguments:
    """
    TrainingArguments is the subset of the arguments we use in our example scripts
    **which relate to the training loop itself**.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."}
    )
    overwrite_output_dir: bool = field(
        default=False, metadata={"help": "Overwrite the content of the output directory"}
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    evaluate_during_training: bool = field(
        default=False, metadata={"help": "Run evaluation during training at each logging step."}
    )
    per_gpu_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for training."})
    per_gpu_eval_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for evaluation."})
    gradient_accumulation_steps: int = field(
        default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."})
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    max_steps: int = field(
        default=-1,
        metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
    )
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    save_total_limit: Optional[int] = field(
        default=None,
        metadata={
            "help": "Limit the total number of checkpoints by deleting the older checkpoints in the output_dir (no deletion by default)."
        },
    )
    eval_all_checkpoints: bool = field(
        default=False,
        metadata={
            "help": "Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number"
        },
    )
    no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"})
    seed: int = field(default=42, metadata={"help": "random seed for initialization"})
    fp16: bool = field(
        default=False,
        metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"},
    )
    fp16_opt_level: str = field(
        default="O1",
        metadata={
            "help": "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
            "See details at https://nvidia.github.io/apex/amp.html"
        },
    )
    local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
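
A short usage sketch for this class, assuming the version introduced in this commit (the option values are illustrative):

from transformers.hf_argparser import HfArgumentParser
from transformers.training_args import TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "/tmp/output", "--do_train", "--learning_rate", "3e-5"]
)
print(training_args.do_train)  # True: the bool field becomes a store_true flag
print(training_args.learning_rate)  # 3e-05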
tests/test_hf_argparser.py (new file)

import argparse
import unittest
from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from transformers.hf_argparser import HfArgumentParser
from transformers.training_args import TrainingArguments

@dataclass
class BasicExample:
    foo: int
    bar: float
    baz: str
    flag: bool


@dataclass
class WithDefaultExample:
    foo: int = 42
    baz: str = field(default="toto", metadata={"help": "help message"})


@dataclass
class WithDefaultBoolExample:
    foo: bool = False
    baz: bool = True


class BasicEnum(Enum):
    titi = "titi"
    toto = "toto"


@dataclass
class EnumExample:
    foo: BasicEnum = BasicEnum.toto


@dataclass
class OptionalExample:
    foo: Optional[int] = None
    bar: Optional[float] = field(default=None, metadata={"help": "help message"})
    baz: Optional[str] = None

class HfArgumentParserTest(unittest.TestCase):
    def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser):
        """
        Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
        """
        self.assertEqual(len(a._actions), len(b._actions))
        for x, y in zip(a._actions, b._actions):
            xx = {k: v for k, v in vars(x).items() if k != "container"}
            yy = {k: v for k, v in vars(y).items() if k != "container"}
            self.assertEqual(xx, yy)
    def test_basic(self):
        parser = HfArgumentParser(BasicExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", type=int, required=True)
        expected.add_argument("--bar", type=float, required=True)
        expected.add_argument("--baz", type=str, required=True)
        expected.add_argument("--flag", action="store_true")
        self.argparsersEqual(parser, expected)

    def test_with_default(self):
        parser = HfArgumentParser(WithDefaultExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", default=42, type=int)
        expected.add_argument("--baz", default="toto", type=str, help="help message")
        self.argparsersEqual(parser, expected)

    def test_with_default_bool(self):
        parser = HfArgumentParser(WithDefaultBoolExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", action="store_true")
        expected.add_argument("--no-baz", action="store_false", dest="baz")
        self.argparsersEqual(parser, expected)

        args = parser.parse_args([])
        self.assertEqual(args, Namespace(foo=False, baz=True))

        args = parser.parse_args(["--foo", "--no-baz"])
        self.assertEqual(args, Namespace(foo=True, baz=False))

    def test_with_enum(self):
        parser = HfArgumentParser(EnumExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum)
        self.argparsersEqual(parser, expected)

        args = parser.parse_args([])
        self.assertEqual(args.foo, BasicEnum.toto)

        args = parser.parse_args(["--foo", "titi"])
        self.assertEqual(args.foo, BasicEnum.titi)

    def test_with_optional(self):
        parser = HfArgumentParser(OptionalExample)

        expected = argparse.ArgumentParser()
        expected.add_argument("--foo", default=None, type=int)
        expected.add_argument("--bar", default=None, type=float, help="help message")
        expected.add_argument("--baz", default=None, type=str)
        self.argparsersEqual(parser, expected)

        args = parser.parse_args([])
        self.assertEqual(args, Namespace(foo=None, bar=None, baz=None))

        args = parser.parse_args("--foo 12 --bar 3.14 --baz 42".split())
        self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42"))

    def test_integration_training_args(self):
        parser = HfArgumentParser(TrainingArguments)
        self.assertIsNotNone(parser)