Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1766fa21
Unverified
Commit
1766fa21
authored
May 10, 2022
by
Dom Miketa
Committed by
GitHub
May 10, 2022
Browse files
train args defaulting None marked as Optional (#17156)
Co-authored-by:
Dom Miketa
<
dmiketa@exscientia.co.uk
>
parent
6d80c92c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
13 deletions
+15
-13
src/transformers/training_args.py
src/transformers/training_args.py
+11
-9
src/transformers/training_args_tf.py
src/transformers/training_args_tf.py
+4
-4
No files found.
src/transformers/training_args.py
View file @
1766fa21
...
...
@@ -582,7 +582,7 @@ class TrainingArguments:
)
no_cuda
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Do not use CUDA even when it is available"
})
seed
:
int
=
field
(
default
=
42
,
metadata
=
{
"help"
:
"Random seed that will be set at the beginning of training."
})
data_seed
:
int
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Random seed to be used with data samplers."
})
data_seed
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Random seed to be used with data samplers."
})
bf16
:
bool
=
field
(
default
=
False
,
metadata
=
{
...
...
@@ -616,14 +616,14 @@ class TrainingArguments:
default
=
False
,
metadata
=
{
"help"
:
"Whether to use full float16 evaluation instead of 32-bit"
},
)
tf32
:
bool
=
field
(
tf32
:
Optional
[
bool
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
},
)
local_rank
:
int
=
field
(
default
=-
1
,
metadata
=
{
"help"
:
"For distributed training: local_rank"
})
xpu_backend
:
str
=
field
(
xpu_backend
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The backend to be used for distributed training on Intel XPU."
,
"choices"
:
[
"mpi"
,
"ccl"
]},
)
...
...
@@ -648,7 +648,7 @@ class TrainingArguments:
dataloader_drop_last
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Drop the last incomplete batch if it is not divisible by the batch size."
}
)
eval_steps
:
int
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Run an evaluation every X steps."
})
eval_steps
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Run an evaluation every X steps."
})
dataloader_num_workers
:
int
=
field
(
default
=
0
,
metadata
=
{
...
...
@@ -770,14 +770,14 @@ class TrainingArguments:
default
=
None
,
metadata
=
{
"help"
:
"The path to a folder with a valid checkpoint for your model."
},
)
hub_model_id
:
str
=
field
(
hub_model_id
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of the repository to keep in sync with the local `output_dir`."
}
)
hub_strategy
:
HubStrategy
=
field
(
default
=
"every_save"
,
metadata
=
{
"help"
:
"The hub strategy to use when `--push_to_hub` is activated."
},
)
hub_token
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The token to use to push to the Model Hub."
})
hub_token
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The token to use to push to the Model Hub."
})
hub_private_repo
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether the model repository is private or not."
})
gradient_checkpointing
:
bool
=
field
(
default
=
False
,
...
...
@@ -793,13 +793,15 @@ class TrainingArguments:
default
=
"auto"
,
metadata
=
{
"help"
:
"Deprecated. Use half_precision_backend instead"
,
"choices"
:
[
"auto"
,
"amp"
,
"apex"
]},
)
push_to_hub_model_id
:
str
=
field
(
push_to_hub_model_id
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of the repository to which push the `Trainer`."
}
)
push_to_hub_organization
:
str
=
field
(
push_to_hub_organization
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of the organization in with to which push the `Trainer`."
}
)
push_to_hub_token
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The token to use to push to the Model Hub."
})
push_to_hub_token
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The token to use to push to the Model Hub."
}
)
_n_gpu
:
int
=
field
(
init
=
False
,
repr
=
False
,
default
=-
1
)
mp_parameters
:
str
=
field
(
default
=
""
,
...
...
src/transformers/training_args_tf.py
View file @
1766fa21
...
...
@@ -14,7 +14,7 @@
import
warnings
from
dataclasses
import
dataclass
,
field
from
typing
import
Tuple
from
typing
import
Optional
,
Tuple
from
.training_args
import
TrainingArguments
from
.utils
import
cached_property
,
is_tf_available
,
logging
,
tf_required
...
...
@@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
Whether to activate the XLA compilation or not.
"""
tpu_name
:
str
=
field
(
tpu_name
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Name of TPU"
},
)
tpu_zone
:
str
=
field
(
tpu_zone
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Zone of TPU"
},
)
gcp_project
:
str
=
field
(
gcp_project
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Name of Cloud TPU-enabled project"
},
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment