"tests/models/vscode:/vscode.git/clone" did not exist on "6c1344444abe45b221953cb14b038b8a06299613"
Unverified Commit 8f6454bf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add proper documentation for Keras callbacks (#15374)

* Add proper documentation for Keras callbacks

* Add dummies
parent 2de90bee
......@@ -15,6 +15,10 @@ specific language governing permissions and limitations under the License.
When training a Transformers model with Keras, there are some library-specific callbacks available to automate common
tasks:
## KerasMetricCallback
[[autodoc]] KerasMetricCallback
## PushToHubCallback
[[autodoc]] keras_callbacks.PushToHubCallback
[[autodoc]] PushToHubCallback
......@@ -1550,7 +1550,7 @@ if is_tf_available():
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["PushToHubCallback"]
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
_import_structure["modeling_tf_outputs"] = []
_import_structure["modeling_tf_utils"] = [
"TFPreTrainedModel",
......@@ -3486,7 +3486,7 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark_tf import TensorFlowBenchmark
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import PushToHubCallback
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
......
......@@ -63,7 +63,7 @@ class KerasMetricCallback(Callback):
supplied.
batch_size (`int`, *optional*):
Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
predict_with_generate: (`bool`, *optional*, defaults to `False`):
predict_with_generate (`bool`, *optional*, defaults to `False`):
Whether we should use `model.generate()` to get outputs for the model.
"""
......@@ -240,18 +240,10 @@ class KerasMetricCallback(Callback):
class PushToHubCallback(Callback):
def __init__(
self,
output_dir: Union[str, Path],
save_strategy: Union[str, IntervalStrategy] = "epoch",
save_steps: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
**model_card_args
):
"""
Callback that will save and push the model to the Hub regularly.
Args:
output_dir (`str`):
The output directory where the model predictions and checkpoints will be written and synced with the
repository on the Hub.
......@@ -262,11 +254,11 @@ class PushToHubCallback(Callback):
- `"epoch"`: Save is done at the end of each epoch.
- `"steps"`: Save is done every `save_steps`
save_steps (`int`, *optional*):
The number of steps between saves when using the "steps" save_strategy.
The number of steps between saves when using the "steps" `save_strategy`.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
hub_model_id (`str`, *optional*):
The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
`"organization_name/model"`.
......@@ -277,8 +269,20 @@ class PushToHubCallback(Callback):
`huggingface-cli login`.
checkpoint (`bool`, *optional*, defaults to `False`):
Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
resumed. Only usable when *save_strategy* is *epoch*.
resumed. Only usable when `save_strategy` is `"epoch"`.
"""
def __init__(
self,
output_dir: Union[str, Path],
save_strategy: Union[str, IntervalStrategy] = "epoch",
save_steps: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
**model_card_args
):
super().__init__()
if checkpoint and save_strategy != "epoch":
raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
......
......@@ -21,6 +21,13 @@ def tf_top_k_top_p_filtering(*args, **kwargs):
requires_backends(tf_top_k_top_p_filtering, ["tf"])
class KerasMetricCallback(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class PushToHubCallback(metaclass=DummyObject):
_backends = ["tf"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment