Unverified Commit b8686174 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Merge pull request #3934 from huggingface/examples_args_from_files

[qol] example scripts: parse args from .args file or JSON
parent f39217a5
...@@ -19,18 +19,15 @@ ...@@ -19,18 +19,15 @@
import dataclasses import dataclasses
import logging import logging
import os import os
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional from typing import Dict, Optional
import numpy as np import numpy as np
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import ( from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
EvalPrediction,
GlueDataset,
GlueDataTrainingArguments,
HfArgumentParser, HfArgumentParser,
Trainer, Trainer,
TrainingArguments, TrainingArguments,
...@@ -69,8 +66,14 @@ def main(): ...@@ -69,8 +66,14 @@ def main():
# or by passing the --help flag to this script. # or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns. # We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, GlueDataTrainingArguments, TrainingArguments)) parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if ( if (
os.path.exists(training_args.output_dir) os.path.exists(training_args.output_dir)
......
import dataclasses import dataclasses
import json
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any, Iterable, NewType, Tuple, Union from typing import Any, Iterable, NewType, Tuple, Union
...@@ -8,6 +11,10 @@ DataClass = NewType("DataClass", Any) ...@@ -8,6 +11,10 @@ DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any) DataClassType = NewType("DataClassType", Any)
def trim_suffix(s: str, suffix: str):
    """Return `s` with a trailing `suffix` removed.

    `s` is returned unchanged when it does not end with `suffix`, or when
    `suffix` is the empty string (every string "ends with" the empty string,
    and slicing with `-0` would wrongly yield an empty result).
    """
    if not s.endswith(suffix):
        return s
    if len(suffix) == 0:
        return s
    return s[: -len(suffix)]
class HfArgumentParser(ArgumentParser): class HfArgumentParser(ArgumentParser):
""" """
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
...@@ -70,7 +77,9 @@ class HfArgumentParser(ArgumentParser): ...@@ -70,7 +77,9 @@ class HfArgumentParser(ArgumentParser):
kwargs["required"] = True kwargs["required"] = True
self.add_argument(field_name, **kwargs) self.add_argument(field_name, **kwargs)
def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]: def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True
) -> Tuple[DataClass, ...]:
""" """
Parse command-line args into instances of the specified dataclass types. Parse command-line args into instances of the specified dataclass types.
...@@ -84,6 +93,10 @@ class HfArgumentParser(ArgumentParser): ...@@ -84,6 +93,10 @@ class HfArgumentParser(ArgumentParser):
(same as argparse.ArgumentParser) (same as argparse.ArgumentParser)
return_remaining_strings: return_remaining_strings:
If true, also return a list of remaining argument strings. If true, also return a list of remaining argument strings.
look_for_args_file:
If true, will look for a ".args" file with the same base name
as the entry point script for this process, and will prepend its
content, if any, to the command line args.
Returns: Returns:
Tuple consisting of: Tuple consisting of:
...@@ -95,6 +108,14 @@ class HfArgumentParser(ArgumentParser): ...@@ -95,6 +108,14 @@ class HfArgumentParser(ArgumentParser):
- The potential list of remaining argument strings. - The potential list of remaining argument strings.
(same as argparse.ArgumentParser.parse_known_args) (same as argparse.ArgumentParser.parse_known_args)
""" """
if look_for_args_file and len(sys.argv):
basename = trim_suffix(sys.argv[0], ".py")
args_file = Path(f"{basename}.args")
if args_file.exists():
fargs = args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
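# The arguments read from the file are placed first: when an ordinary
# option is given twice, argparse keeps the last value, so arguments
# passed explicitly on the command line override those from the file.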
namespace, remaining_args = self.parse_known_args(args=args) namespace, remaining_args = self.parse_known_args(args=args)
outputs = [] outputs = []
for dtype in self.dataclass_types: for dtype in self.dataclass_types:
...@@ -111,3 +132,17 @@ class HfArgumentParser(ArgumentParser): ...@@ -111,3 +132,17 @@ class HfArgumentParser(ArgumentParser):
return (*outputs, remaining_args) return (*outputs, remaining_args)
else: else:
return (*outputs,) return (*outputs,)
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
    """
    Alternative helper method that does not use `argparse` at all,
    instead loading a json file and populating the dataclass types.

    Args:
        json_file:
            Path to a JSON file whose top-level value is an object
            mapping field names to values.

    Returns:
        Tuple with one instance per entry of `self.dataclass_types`, in
        the same order. Each instance only receives the keys that match
        one of its own init fields; all other keys are ignored for it.
    """
    # Hand raw bytes to `json.loads` so that it detects the encoding
    # (UTF-8, with or without BOM, UTF-16 or UTF-32) itself instead of
    # relying on the platform's locale encoding.
    data = json.loads(Path(json_file).read_bytes())
    outputs = []
    for dtype in self.dataclass_types:
        # Fields declared with `init=False` are not accepted by the
        # generated `__init__`, so they must not be forwarded to it.
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in data.items() if k in keys}
        outputs.append(dtype(**inputs))
    return (*outputs,)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment