Unverified Commit 8f6454bf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add proper documentation for Keras callbacks (#15374)

* Add proper documentation for Keras callbacks

* Add dummies
parent 2de90bee
...@@ -15,6 +15,10 @@ specific language governing permissions and limitations under the License. ...@@ -15,6 +15,10 @@ specific language governing permissions and limitations under the License.
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
tasks:
## KerasMetricCallback
[[autodoc]] KerasMetricCallback
## PushToHubCallback

[[autodoc]] PushToHubCallback
...@@ -1550,7 +1550,7 @@ if is_tf_available(): ...@@ -1550,7 +1550,7 @@ if is_tf_available():
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"] _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"] _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"] _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["PushToHubCallback"] _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
_import_structure["modeling_tf_outputs"] = [] _import_structure["modeling_tf_outputs"] = []
_import_structure["modeling_tf_utils"] = [ _import_structure["modeling_tf_utils"] = [
"TFPreTrainedModel", "TFPreTrainedModel",
...@@ -3486,7 +3486,7 @@ if TYPE_CHECKING: ...@@ -3486,7 +3486,7 @@ if TYPE_CHECKING:
# Benchmarks # Benchmarks
from .benchmark.benchmark_tf import TensorFlowBenchmark from .benchmark.benchmark_tf import TensorFlowBenchmark
from .generation_tf_utils import tf_top_k_top_p_filtering from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import PushToHubCallback from .keras_callbacks import KerasMetricCallback, PushToHubCallback
from .modeling_tf_layoutlm import ( from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM, TFLayoutLMForMaskedLM,
......
...@@ -63,7 +63,7 @@ class KerasMetricCallback(Callback): ...@@ -63,7 +63,7 @@ class KerasMetricCallback(Callback):
supplied. supplied.
batch_size (`int`, *optional*): batch_size (`int`, *optional*):
Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`. Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
predict_with_generate: (`bool`, *optional*, defaults to `False`): predict_with_generate (`bool`, *optional*, defaults to `False`):
Whether we should use `model.generate()` to get outputs for the model. Whether we should use `model.generate()` to get outputs for the model.
""" """
...@@ -240,18 +240,10 @@ class KerasMetricCallback(Callback): ...@@ -240,18 +240,10 @@ class KerasMetricCallback(Callback):
class PushToHubCallback(Callback): class PushToHubCallback(Callback):
def __init__(
self,
output_dir: Union[str, Path],
save_strategy: Union[str, IntervalStrategy] = "epoch",
save_steps: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
**model_card_args
):
""" """
Callback that will save and push the model to the Hub regularly.
Args:
output_dir (`str`): output_dir (`str`):
The output directory where the model predictions and checkpoints will be written and synced with the The output directory where the model predictions and checkpoints will be written and synced with the
repository on the Hub. repository on the Hub.
...@@ -262,11 +254,11 @@ class PushToHubCallback(Callback): ...@@ -262,11 +254,11 @@ class PushToHubCallback(Callback):
- `"epoch"`: Save is done at the end of each epoch. - `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps` - `"steps"`: Save is done every `save_steps`
save_steps (`int`, *optional*): save_steps (`int`, *optional*):
The number of steps between saves when using the "steps" save_strategy. The number of steps between saves when using the "steps" `save_strategy`.
tokenizer (`PreTrainedTokenizerBase`, *optional*): tokenizer (`PreTrainedTokenizerBase`, *optional*):
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights. The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
hub_model_id (`str`, *optional*): hub_model_id (`str`, *optional*):
The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name, which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
`"organization_name/model"`. `"organization_name/model"`.
...@@ -277,8 +269,20 @@ class PushToHubCallback(Callback): ...@@ -277,8 +269,20 @@ class PushToHubCallback(Callback):
`huggingface-cli login`. `huggingface-cli login`.
checkpoint (`bool`, *optional*, defaults to `False`): checkpoint (`bool`, *optional*, defaults to `False`):
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
resumed. Only usable when *save_strategy* is *epoch*. resumed. Only usable when `save_strategy` is `"epoch"`.
""" """
def __init__(
self,
output_dir: Union[str, Path],
save_strategy: Union[str, IntervalStrategy] = "epoch",
save_steps: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
**model_card_args
):
super().__init__() super().__init__()
if checkpoint and save_strategy != "epoch": if checkpoint and save_strategy != "epoch":
raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!") raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
......
...@@ -21,6 +21,13 @@ def tf_top_k_top_p_filtering(*args, **kwargs): ...@@ -21,6 +21,13 @@ def tf_top_k_top_p_filtering(*args, **kwargs):
requires_backends(tf_top_k_top_p_filtering, ["tf"]) requires_backends(tf_top_k_top_p_filtering, ["tf"])
class KerasMetricCallback(metaclass=DummyObject):
    """Placeholder for `KerasMetricCallback` when TensorFlow is not installed.

    Instantiating it raises an informative ImportError (via
    `requires_backends`) instead of a bare NameError at import time.
    """

    # Backend(s) required for the real implementation of this object.
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        # Immediately signal the missing backend; arguments are ignored.
        requires_backends(self, ["tf"])
class PushToHubCallback(metaclass=DummyObject): class PushToHubCallback(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment