Unverified Commit 4f8361af authored by Jannis Born's avatar Jannis Born Committed by GitHub
Browse files

Unifying training argument type annotations (#17934)

* doc: Unify training arg type annotations

* wip: extracting enum type from Union

* blackening
parent 205bc415
...@@ -92,7 +92,11 @@ class HfArgumentParser(ArgumentParser): ...@@ -92,7 +92,11 @@ class HfArgumentParser(ArgumentParser):
" the argument parser only supports one type per argument." " the argument parser only supports one type per argument."
f" Problem encountered in field '{field.name}'." f" Problem encountered in field '{field.name}'."
) )
if bool not in field.type.__args__: if type(None) not in field.type.__args__:
# filter `str` in Union
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
origin_type = getattr(field.type, "__origin__", field.type)
elif bool not in field.type.__args__:
# filter `NoneType` in Union (except for `Union[bool, NoneType]`) # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
field.type = ( field.type = (
field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1] field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
......
...@@ -20,7 +20,7 @@ import warnings ...@@ -20,7 +20,7 @@ import warnings
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
from .debug_utils import DebugOption from .debug_utils import DebugOption
from .trainer_utils import ( from .trainer_utils import (
...@@ -493,7 +493,7 @@ class TrainingArguments: ...@@ -493,7 +493,7 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluation_strategy: IntervalStrategy = field( evaluation_strategy: Union[IntervalStrategy, str] = field(
default="no", default="no",
metadata={"help": "The evaluation strategy to use."}, metadata={"help": "The evaluation strategy to use."},
) )
...@@ -559,7 +559,7 @@ class TrainingArguments: ...@@ -559,7 +559,7 @@ class TrainingArguments:
default=-1, default=-1,
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
) )
lr_scheduler_type: SchedulerType = field( lr_scheduler_type: Union[SchedulerType, str] = field(
default="linear", default="linear",
metadata={"help": "The scheduler type to use."}, metadata={"help": "The scheduler type to use."},
) )
...@@ -596,14 +596,14 @@ class TrainingArguments: ...@@ -596,14 +596,14 @@ class TrainingArguments:
}, },
) )
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
logging_strategy: IntervalStrategy = field( logging_strategy: Union[IntervalStrategy, str] = field(
default="steps", default="steps",
metadata={"help": "The logging strategy to use."}, metadata={"help": "The logging strategy to use."},
) )
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."}) logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
save_strategy: IntervalStrategy = field( save_strategy: Union[IntervalStrategy, str] = field(
default="steps", default="steps",
metadata={"help": "The checkpoint save strategy to use."}, metadata={"help": "The checkpoint save strategy to use."},
) )
...@@ -815,7 +815,7 @@ class TrainingArguments: ...@@ -815,7 +815,7 @@ class TrainingArguments:
label_smoothing_factor: float = field( label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
) )
optim: OptimizerNames = field( optim: Union[OptimizerNames, str] = field(
default="adamw_hf", default="adamw_hf",
metadata={"help": "The optimizer to use."}, metadata={"help": "The optimizer to use."},
) )
...@@ -868,7 +868,7 @@ class TrainingArguments: ...@@ -868,7 +868,7 @@ class TrainingArguments:
hub_model_id: Optional[str] = field( hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
) )
hub_strategy: HubStrategy = field( hub_strategy: Union[HubStrategy, str] = field(
default="every_save", default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment