Unverified Commit 4f8361af authored by Jannis Born's avatar Jannis Born Committed by GitHub
Browse files

Unifying training argument type annotations (#17934)

* doc: Unify training arg type annotations

* wip: extracting enum type from Union

* blackening
parent 205bc415
......@@ -92,7 +92,11 @@ class HfArgumentParser(ArgumentParser):
" the argument parser only supports one type per argument."
f" Problem encountered in field '{field.name}'."
)
if bool not in field.type.__args__:
if type(None) not in field.type.__args__:
# filter `str` in Union
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
origin_type = getattr(field.type, "__origin__", field.type)
elif bool not in field.type.__args__:
# filter `NoneType` in Union (except for `Union[bool, NoneType]`)
field.type = (
field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
......
......@@ -20,7 +20,7 @@ import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from .debug_utils import DebugOption
from .trainer_utils import (
......@@ -493,7 +493,7 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluation_strategy: IntervalStrategy = field(
evaluation_strategy: Union[IntervalStrategy, str] = field(
default="no",
metadata={"help": "The evaluation strategy to use."},
)
......@@ -559,7 +559,7 @@ class TrainingArguments:
default=-1,
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
)
lr_scheduler_type: SchedulerType = field(
lr_scheduler_type: Union[SchedulerType, str] = field(
default="linear",
metadata={"help": "The scheduler type to use."},
)
......@@ -596,14 +596,14 @@ class TrainingArguments:
},
)
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
logging_strategy: IntervalStrategy = field(
logging_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The logging strategy to use."},
)
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
save_strategy: IntervalStrategy = field(
save_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
......@@ -815,7 +815,7 @@ class TrainingArguments:
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
optim: OptimizerNames = field(
optim: Union[OptimizerNames, str] = field(
default="adamw_hf",
metadata={"help": "The optimizer to use."},
)
......@@ -868,7 +868,7 @@ class TrainingArguments:
hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_strategy: HubStrategy = field(
hub_strategy: Union[HubStrategy, str] = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment