Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8f6454bf
Unverified
Commit
8f6454bf
authored
Jan 27, 2022
by
Sylvain Gugger
Committed by
GitHub
Jan 27, 2022
Browse files
Add proper documentation for Keras callbacks (#15374)
* Add proper documentation for Keras callbacks * Add dummies
parent
2de90bee
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
35 additions
and
20 deletions
+35
-20
docs/source/main_classes/keras_callbacks.mdx
docs/source/main_classes/keras_callbacks.mdx
+5
-1
src/transformers/__init__.py
src/transformers/__init__.py
+2
-2
src/transformers/keras_callbacks.py
src/transformers/keras_callbacks.py
+21
-17
src/transformers/utils/dummy_tf_objects.py
src/transformers/utils/dummy_tf_objects.py
+7
-0
No files found.
docs/source/main_classes/keras_callbacks.mdx
View file @
8f6454bf
...
@@ -15,6 +15,10 @@ specific language governing permissions and limitations under the License.
...
@@ -15,6 +15,10 @@ specific language governing permissions and limitations under the License.
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
tasks:
tasks:
## KerasMetricCallback
[[autodoc]] KerasMetricCallback
## PushToHubCallback
## PushToHubCallback
[[autodoc]]
keras_callbacks.
PushToHubCallback
[[autodoc]] PushToHubCallback
src/transformers/__init__.py
View file @
8f6454bf
...
@@ -1550,7 +1550,7 @@ if is_tf_available():
...
@@ -1550,7 +1550,7 @@ if is_tf_available():
_import_structure
[
"benchmark.benchmark_args_tf"
]
=
[
"TensorFlowBenchmarkArguments"
]
_import_structure
[
"benchmark.benchmark_args_tf"
]
=
[
"TensorFlowBenchmarkArguments"
]
_import_structure
[
"benchmark.benchmark_tf"
]
=
[
"TensorFlowBenchmark"
]
_import_structure
[
"benchmark.benchmark_tf"
]
=
[
"TensorFlowBenchmark"
]
_import_structure
[
"generation_tf_utils"
]
=
[
"tf_top_k_top_p_filtering"
]
_import_structure
[
"generation_tf_utils"
]
=
[
"tf_top_k_top_p_filtering"
]
_import_structure
[
"keras_callbacks"
]
=
[
"PushToHubCallback"
]
_import_structure
[
"keras_callbacks"
]
=
[
"KerasMetricCallback"
,
"PushToHubCallback"
]
_import_structure
[
"modeling_tf_outputs"
]
=
[]
_import_structure
[
"modeling_tf_outputs"
]
=
[]
_import_structure
[
"modeling_tf_utils"
]
=
[
_import_structure
[
"modeling_tf_utils"
]
=
[
"TFPreTrainedModel"
,
"TFPreTrainedModel"
,
...
@@ -3486,7 +3486,7 @@ if TYPE_CHECKING:
...
@@ -3486,7 +3486,7 @@ if TYPE_CHECKING:
# Benchmarks
# Benchmarks
from
.benchmark.benchmark_tf
import
TensorFlowBenchmark
from
.benchmark.benchmark_tf
import
TensorFlowBenchmark
from
.generation_tf_utils
import
tf_top_k_top_p_filtering
from
.generation_tf_utils
import
tf_top_k_top_p_filtering
from
.keras_callbacks
import
PushToHubCallback
from
.keras_callbacks
import
KerasMetricCallback
,
PushToHubCallback
from
.modeling_tf_layoutlm
import
(
from
.modeling_tf_layoutlm
import
(
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
,
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
,
TFLayoutLMForMaskedLM
,
TFLayoutLMForMaskedLM
,
...
...
src/transformers/keras_callbacks.py
View file @
8f6454bf
...
@@ -63,7 +63,7 @@ class KerasMetricCallback(Callback):
...
@@ -63,7 +63,7 @@ class KerasMetricCallback(Callback):
supplied.
supplied.
batch_size (`int`, *optional*):
batch_size (`int`, *optional*):
Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
predict_with_generate
:
(`bool`, *optional*, defaults to `False`):
predict_with_generate (`bool`, *optional*, defaults to `False`):
Whether we should use `model.generate()` to get outputs for the model.
Whether we should use `model.generate()` to get outputs for the model.
"""
"""
...
@@ -240,18 +240,10 @@ class KerasMetricCallback(Callback):
...
@@ -240,18 +240,10 @@ class KerasMetricCallback(Callback):
class
PushToHubCallback
(
Callback
):
class
PushToHubCallback
(
Callback
):
def
__init__
(
self
,
output_dir
:
Union
[
str
,
Path
],
save_strategy
:
Union
[
str
,
IntervalStrategy
]
=
"epoch"
,
save_steps
:
Optional
[
int
]
=
None
,
tokenizer
:
Optional
[
PreTrainedTokenizerBase
]
=
None
,
hub_model_id
:
Optional
[
str
]
=
None
,
hub_token
:
Optional
[
str
]
=
None
,
checkpoint
:
bool
=
False
,
**
model_card_args
):
"""
"""
Callback that will save and push the model to the Hub regularly.
Args:
output_dir (`str`):
output_dir (`str`):
The output directory where the model predictions and checkpoints will be written and synced with the
The output directory where the model predictions and checkpoints will be written and synced with the
repository on the Hub.
repository on the Hub.
...
@@ -262,11 +254,11 @@ class PushToHubCallback(Callback):
...
@@ -262,11 +254,11 @@ class PushToHubCallback(Callback):
- `"epoch"`: Save is done at the end of each epoch.
- `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps`
- `"steps"`: Save is done every `save_steps`
save_steps (`int`, *optional*):
save_steps (`int`, *optional*):
The number of steps between saves when using the "steps" save_strategy.
The number of steps between saves when using the "steps"
`
save_strategy
`
.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
tokenizer (`PreTrainedTokenizerBase`, *optional*):
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
hub_model_id (`str`, *optional*):
hub_model_id (`str`, *optional*):
The name of the repository to keep in sync with the local
*
output_dir
*
. It can be a simple model ID in
The name of the repository to keep in sync with the local
`
output_dir
`
. It can be a simple model ID in
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
`"organization_name/model"`.
`"organization_name/model"`.
...
@@ -277,8 +269,20 @@ class PushToHubCallback(Callback):
...
@@ -277,8 +269,20 @@ class PushToHubCallback(Callback):
`huggingface-cli login`.
`huggingface-cli login`.
checkpoint (`bool`, *optional*, defaults to `False`):
checkpoint (`bool`, *optional*, defaults to `False`):
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
resumed. Only usable when
*
save_strategy
*
is
*
epoch
*
.
resumed. Only usable when
`
save_strategy
`
is
`"
epoch
"`
.
"""
"""
def
__init__
(
self
,
output_dir
:
Union
[
str
,
Path
],
save_strategy
:
Union
[
str
,
IntervalStrategy
]
=
"epoch"
,
save_steps
:
Optional
[
int
]
=
None
,
tokenizer
:
Optional
[
PreTrainedTokenizerBase
]
=
None
,
hub_model_id
:
Optional
[
str
]
=
None
,
hub_token
:
Optional
[
str
]
=
None
,
checkpoint
:
bool
=
False
,
**
model_card_args
):
super
().
__init__
()
super
().
__init__
()
if
checkpoint
and
save_strategy
!=
"epoch"
:
if
checkpoint
and
save_strategy
!=
"epoch"
:
raise
ValueError
(
"Cannot save checkpoints when save_strategy is not 'epoch'!"
)
raise
ValueError
(
"Cannot save checkpoints when save_strategy is not 'epoch'!"
)
...
...
src/transformers/utils/dummy_tf_objects.py
View file @
8f6454bf
...
@@ -21,6 +21,13 @@ def tf_top_k_top_p_filtering(*args, **kwargs):
...
@@ -21,6 +21,13 @@ def tf_top_k_top_p_filtering(*args, **kwargs):
requires_backends
(
tf_top_k_top_p_filtering
,
[
"tf"
])
requires_backends
(
tf_top_k_top_p_filtering
,
[
"tf"
])
class
KerasMetricCallback
(
metaclass
=
DummyObject
):
_backends
=
[
"tf"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"tf"
])
class
PushToHubCallback
(
metaclass
=
DummyObject
):
class
PushToHubCallback
(
metaclass
=
DummyObject
):
_backends
=
[
"tf"
]
_backends
=
[
"tf"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment