"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "eb849f6604c7dcc0e96d68f4851e52e253b9f0e5"
Unverified Commit 1766fa21 authored by Dom Miketa, committed by GitHub
Browse files

Train args defaulting to None marked as Optional (#17156)


Co-authored-by: Dom Miketa <dmiketa@exscientia.co.uk>
parent 6d80c92c
...@@ -582,7 +582,7 @@ class TrainingArguments: ...@@ -582,7 +582,7 @@ class TrainingArguments:
) )
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"}) no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."}) data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
bf16: bool = field( bf16: bool = field(
default=False, default=False,
metadata={ metadata={
...@@ -616,14 +616,14 @@ class TrainingArguments: ...@@ -616,14 +616,14 @@ class TrainingArguments:
default=False, default=False,
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"}, metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
) )
tf32: bool = field( tf32: Optional[bool] = field(
default=None, default=None,
metadata={ metadata={
"help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change." "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
}, },
) )
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
xpu_backend: str = field( xpu_backend: Optional[str] = field(
default=None, default=None,
metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]}, metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
) )
...@@ -648,7 +648,7 @@ class TrainingArguments: ...@@ -648,7 +648,7 @@ class TrainingArguments:
dataloader_drop_last: bool = field( dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
) )
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
dataloader_num_workers: int = field( dataloader_num_workers: int = field(
default=0, default=0,
metadata={ metadata={
...@@ -770,14 +770,14 @@ class TrainingArguments: ...@@ -770,14 +770,14 @@ class TrainingArguments:
default=None, default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."}, metadata={"help": "The path to a folder with a valid checkpoint for your model."},
) )
hub_model_id: str = field( hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
) )
hub_strategy: HubStrategy = field( hub_strategy: HubStrategy = field(
default="every_save", default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
) )
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."}) hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
gradient_checkpointing: bool = field( gradient_checkpointing: bool = field(
default=False, default=False,
...@@ -793,13 +793,15 @@ class TrainingArguments: ...@@ -793,13 +793,15 @@ class TrainingArguments:
default="auto", default="auto",
metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]}, metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
) )
push_to_hub_model_id: str = field( push_to_hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
) )
push_to_hub_organization: str = field( push_to_hub_organization: Optional[str] = field(
default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."} default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
) )
push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) push_to_hub_token: Optional[str] = field(
default=None, metadata={"help": "The token to use to push to the Model Hub."}
)
_n_gpu: int = field(init=False, repr=False, default=-1) _n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field( mp_parameters: str = field(
default="", default="",
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Optional, Tuple
from .training_args import TrainingArguments from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required from .utils import cached_property, is_tf_available, logging, tf_required
...@@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments): ...@@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
Whether to activate the XLA compilation or not. Whether to activate the XLA compilation or not.
""" """
tpu_name: str = field( tpu_name: Optional[str] = field(
default=None, default=None,
metadata={"help": "Name of TPU"}, metadata={"help": "Name of TPU"},
) )
tpu_zone: str = field( tpu_zone: Optional[str] = field(
default=None, default=None,
metadata={"help": "Zone of TPU"}, metadata={"help": "Zone of TPU"},
) )
gcp_project: str = field( gcp_project: Optional[str] = field(
default=None, default=None,
metadata={"help": "Name of Cloud TPU-enabled project"}, metadata={"help": "Name of Cloud TPU-enabled project"},
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment