Commit a6d78dd4 authored by Mark Daoust, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 456569753
parent 45dc706d
# Public docs for TensorFlow Models
This directory contains the top-level public documentation for
[TensorFlow Models](https://github.com/tensorflow/models).
This directory is mirrored to https://tensorflow.org/tfmodels, and is mainly
concerned with documenting the tools provided in the `tensorflow_models` pip
package (including `orbit`).
API reference pages are
[available on the site](https://www.tensorflow.org/api_docs/more).
The
[Official Models](https://github.com/tensorflow/models/blob/master/official/projects)
and [Research Models](https://github.com/tensorflow/models/blob/master/research)
directories are not described in detail here; refer to the individual project
directories for more information.
# Model Garden overview
The TensorFlow Model Garden provides implementations of many state-of-the-art
machine learning (ML) models for vision and natural language processing (NLP),
as well as workflow tools to let you quickly configure and run those models on
standard datasets. Whether you are looking to benchmark performance for a
well-known model, verify the results of recently released research, or extend
existing models, the Model Garden can help you drive your ML research and
applications forward.
The Model Garden includes the following resources for machine learning
developers:
- [**Official models**](#official) for vision and NLP, maintained by Google
engineers
- [**Research models**](#research) published as part of ML research papers
- [**Training experiment framework**](#training_framework) for fast,
declarative training configuration of official models
- [**Specialized ML operations**](#ops) for vision and natural language
processing (NLP)
- [**Model training loop**](#orbit) management with Orbit
These resources are built to be used with the TensorFlow Core framework and
integrate with your existing TensorFlow development projects. Model
Garden resources are also provided under an [open
source](https://github.com/tensorflow/models/blob/master/LICENSE) license, so
you can freely extend and distribute the models and tools.
Practical ML models are computationally intensive to train and run, and may
require accelerators such as Graphics Processing Units (GPUs) and Tensor
Processing Units (TPUs). Most of the models in Model Garden were trained on
large datasets using TPUs. However, you can also train and run these models on
GPU and CPU processors.
## Model Garden models
The machine learning models in the Model Garden include full code so you can
test, train, or re-train them for research and experimentation. The Model Garden
includes two primary categories of models: *official models* and *research
models*.
### Official models {:#official}
The [Official Models](https://github.com/tensorflow/models/tree/master/official)
repository is a collection of state-of-the-art models, with a focus on
vision and natural language processing (NLP).
These models are implemented using current TensorFlow 2.x high-level
APIs. Model libraries in this repository are optimized for fast performance and
actively maintained by Google engineers. The official models include additional
metadata you can use to quickly configure experiments using the Model Garden
[training experiment framework](#training_framework).
### Research models {:#research}
The [Research Models](https://github.com/tensorflow/models/tree/master/research)
repository is a collection of models published as code resources for research
papers. These models are implemented using both TensorFlow 1.x and 2.x. Model
libraries in the research folder are supported by the code owners and the
research community.
## Training experiment framework {:#training_framework}
The Model Garden training experiment framework lets you quickly assemble and run
training experiments using its official models and standard datasets. The
training framework uses additional metadata included with the Model Garden's
official models to allow you to configure models quickly using a declarative
programming model. You can define a training experiment using Python commands in
the
[TensorFlow Model library](https://www.tensorflow.org/api_docs/python/tfm/core)
or configure training using a YAML configuration file, like this
[example](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml).
The training framework uses
[`tfm.core.base_trainer.ExperimentConfig`](https://www.tensorflow.org/api_docs/python/tfm/core/base_trainer/ExperimentConfig)
as the configuration object, which contains the following top-level
configuration objects:
- [`runtime`](https://www.tensorflow.org/api_docs/python/tfm/core/base_task/RuntimeConfig):
Defines the processing hardware, distribution strategy, and other
performance optimizations
- [`task`](https://www.tensorflow.org/api_docs/python/tfm/core/config_definitions/TaskConfig):
Defines the model, training data, losses, and initialization
- [`trainer`](https://www.tensorflow.org/api_docs/python/tfm/core/base_trainer/TrainerConfig):
Defines the optimizer, training loops, evaluation loops, summaries, and
checkpoints
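
For illustration, here is a minimal Python sketch of how these objects fit together, using the `resnet_imagenet` experiment from the image classification tutorial (any registered experiment name works the same way; the specific overrides below are just examples):

```python
import tensorflow_models as tfm

# Look up a registered experiment configuration by name.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')

# The returned ExperimentConfig exposes the `runtime`, `task`, and `trainer`
# sub-configs, which you can override declaratively before training.
exp_config.task.model.num_classes = 10
exp_config.trainer.train_steps = 5000
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'

print(exp_config.trainer.train_steps)
```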
For a complete example using the Model Garden training experiment framework, see
the [Image classification with Model Garden](vision/image_classification.ipynb)
tutorial. For information on the training experiment framework, check out the
[TensorFlow Models API documentation](https://tensorflow.org/api_docs/python/tfm/core).
If you are looking for a solution to manage training loops for your model
training experiments, check out [Orbit](#orbit).
## Specialized ML operations {:#ops}
The Model Garden contains many vision and NLP operations specifically designed
to execute state-of-the-art models that run efficiently on GPUs and TPUs. Review
the TensorFlow Models Vision library API docs for a list of specialized
[vision operations](https://www.tensorflow.org/api_docs/python/tfm/vision).
Review the TensorFlow Models NLP Library API docs for a list of
[NLP operations](https://www.tensorflow.org/api_docs/python/tfm/nlp). These
libraries also include additional utility functions used for vision and NLP data
processing, training, and model execution.
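
As a small, hedged sketch (assuming the `tf-models-official` package is installed), one of these NLP building blocks, `tfm.nlp.layers.TransformerEncoderBlock`, can be used like an ordinary Keras layer:

```python
import numpy as np
import tensorflow_models as tfm

# A single Transformer encoder block from the NLP modeling library.
block = tfm.nlp.layers.TransformerEncoderBlock(
    num_attention_heads=2,
    inner_dim=128,
    inner_activation='relu')

# A dummy [batch, sequence_length, hidden_size] input.
x = np.zeros([1, 8, 64], dtype=np.float32)
y = block(x)
print(y.shape)  # Same shape as the input: (1, 8, 64)
```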
## Training loops with Orbit {:#orbit}
The Orbit tool is a flexible, lightweight library designed to make it easier to
write custom training loops in TensorFlow 2.x, and works well with the Model
Garden [training experiment framework](#training_framework). Orbit handles
common model training tasks such as saving checkpoints, running model
evaluations, and setting up summary writing. It seamlessly integrates with
`tf.distribute` and supports running on different device types, including CPU,
GPU, and TPU hardware. The Orbit tool is also [open
source](https://github.com/tensorflow/models/blob/master/orbit/LICENSE), so you
can extend and adapt it to your model training needs.
You generally train TensorFlow models either by writing a
[custom training loop](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)
or by using the high-level Keras
[Model.fit](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)
function. For simple models, you can define and manage a custom training loop
with low-level TensorFlow methods such as `tf.GradientTape` or `tf.function`.
However, if your model is complex and your training loop requires more flexible
control or customization, consider Orbit. You can define most of your training
loop by subclassing `orbit.AbstractTrainer` or `orbit.StandardTrainer`.
Learn more about the Orbit tool in the
[Orbit API documentation](https://www.tensorflow.org/api_docs/python/orbit).
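
To give a flavor of the API, here is a minimal sketch of a `StandardTrainer` subclass; the `model`, `optimizer`, and `train_dataset` objects are assumed to be created elsewhere under the active distribution strategy (the Orbit tutorial below walks through a complete version of this pattern):

```python
import orbit
import tensorflow as tf

class SketchTrainer(orbit.StandardTrainer):
  """A minimal custom trainer sketch; assumes `model`, `optimizer`, and
  `train_dataset` were created under the active distribution strategy."""

  def __init__(self, train_dataset, model, optimizer):
    self.strategy = tf.distribute.get_strategy()
    self.model = model
    self.optimizer = optimizer
    self.global_step = optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    super().__init__(train_dataset)

  def train_loop_begin(self):
    # Reset metrics at the start of each chunk of `steps_per_loop` steps.
    self.train_loss.reset_states()

  def train_step(self, iterator):
    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = self.model(features, training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
      grads = tape.gradient(loss, self.model.trainable_variables)
      self.optimizer.apply_gradients(
          zip(grads, self.model.trainable_variables))
      self.train_loss.update_state(loss)

    self.strategy.run(step_fn, args=(next(iterator),))

  def train_loop_end(self):
    # Returned values are logged and summarized by the orbit.Controller.
    return {self.train_loss.name: self.train_loss.result()}
```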
Note: You can also customize how the Keras API executes training, mainly by
overriding the `Model.train_step` method or by using `keras.callbacks` such as
`callbacks.ModelCheckpoint` or `callbacks.TensorBoard`. For more information
about modifying the behavior of `train_step`, check out the
[Customize what happens in Model.fit](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)
page.
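
For reference, a minimal sketch of the `Model.train_step` override, following the pattern from that guide:

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):

  def train_step(self, data):
    x, y = data
    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      # Uses the loss configured in `Model.compile`.
      loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
    grads = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}
```

A model defined this way still trains with `Model.fit`; each optimization step just runs this custom logic.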
docs/nlp/customize_encoder.ipynb
@@ -48,16 +48,16 @@
 "source": [
 "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/customize_encoder\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
+" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 "\u003c/table\u003e"
 ]
@@ -191,7 +191,7 @@
 "id": "Qe2UWI6_tsHo"
 },
 "source": [
-"`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb). We skip the code that trains the model here.\n",
+"`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook. We skip the code that trains the model here.\n",
 "\n",
 "After training, we can apply the model to do prediction.\n"
 ]
docs/nlp/decoding_api.ipynb
@@ -48,16 +48,16 @@
 "source": [
 "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/tutorials/decoding_api_in_tf_nlp.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/decoding_api\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/decoding_api_in_tf_nlp.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/decoding_api_in_tf_nlp.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/decoding_api_in_tf_nlp.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
+" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 "\u003c/table\u003e"
 ]
@@ -226,10 +226,10 @@
 "outputs": [],
 "source": [
 "params = {\n",
-" 'num_heads': 2\n",
-" 'num_layers': 2\n",
-" 'batch_size': 2\n",
-" 'n_dims': 256\n",
+" 'num_heads': 2,\n",
+" 'num_layers': 2,\n",
+" 'batch_size': 2,\n",
+" 'n_dims': 256,\n",
 " 'max_decode_length': 4}"
 ]
 },
docs/nlp/index.ipynb
@@ -48,16 +48,16 @@
 "source": [
 "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 " \u003ctd\u003e\n",
-" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
+" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
 " \u003c/td\u003e\n",
 "\u003c/table\u003e"
 ]
@@ -538,7 +538,7 @@
 "id": "mzBqOylZo3og"
 },
 "source": [
-"With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example."
+"With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook for the full example."
 ]
 }
 ],
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFdPvlXBOdUN"
},
"source": [
"# Training with Orbit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfBg1C5NB3X0"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/orbit\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "456h0idS2Xcq"
},
"source": [
"This example will work through fine-tuning a BERT model using the [Orbit](https://www.tensorflow.org/api_docs/python/orbit) training library.\n",
"\n",
"Orbit is a flexible, lightweight library designed to make it easy to write [custom training loops](https://www.tensorflow.org/tutorials/distribute/custom_training) in TensorFlow. Orbit handles common model training tasks such as saving checkpoints, running model evaluations, and setting up summary writing, while giving users full control over implementing the inner training loop. It integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU).\n",
"\n",
"Most examples on [tensorflow.org](https://www.tensorflow.org/) use custom training loops or [model.fit()](https://www.tensorflow.org/api_docs/python/tf/keras/Model) from Keras. Orbit is a good alternative to `model.fit` if your model is complex and your training loop requires more flexibility, control, or customization. Also, using Orbit can simplify the code when there are many different model architectures that all use the same custom training loop.\n",
"\n",
"This tutorial focuses on setting up and using Orbit, rather than details about BERT, model construction, and data processing. For more in-depth tutorials on these topics, refer to the following tutorials:\n",
"\n",
"* [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) - which goes into detail on these sub-topics.\n",
"* [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) - which generalizes the code to run any BERT configuration on any [GLUE](https://www.tensorflow.org/datasets/catalog/glue) sub-task, and runs on TPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TJ4m3khW3p_W"
},
"source": [
"## Install the TensorFlow Models package\n",
"\n",
"Install and import the necessary packages, then configure all the objects necessary for training a model.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FZlj0U8Aq9Gt"
},
"outputs": [],
"source": [
"# Uninstall opencv-python to avoid a conflict (in Colab) with the opencv-python-headless package that tf-models uses.\n",
"!pip uninstall -y opencv-python\n",
"!pip install -U -q \"tensorflow\u003e=2.9.0\" \"tf-models-official\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MEJkRrmapr16"
},
"source": [
"The `tf-models-official` package contains both the `orbit` and `tensorflow_models` modules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dUVPW84Zucuq"
},
"outputs": [],
"source": [
"import tensorflow_models as tfm\n",
"import orbit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "18Icocf3lwYD"
},
"source": [
"## Setup for training\n",
"\n",
"This tutorial does not focus on configuring the environment, building the model and optimizer, and loading data. All these techniques are covered in more detail in the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) and [Fine tune BERT with GLUE](https://www.tensorflow.org/text/tutorials/bert_glue) tutorials.\n",
"\n",
"To view how the training is set up for this tutorial, expand the rest of this section.\n",
"\n",
" \u003c!-- \u003cdiv class=\"tfo-display-only-on-site\"\u003e\u003cdevsite-expandable\u003e\n",
" \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eExpand Section\u003c/button\u003e --\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ljy0z-i3okCS"
},
"source": [
"### Import the necessary packages\n",
"\n",
"Import the BERT model and dataset building library from [Tensorflow Model Garden](https://github.com/tensorflow/models)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCBo6wxA2b5n"
},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import pathlib\n",
"import tempfile\n",
"import time\n",
"\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PG1kwhnvq3VC"
},
"outputs": [],
"source": [
"from official.nlp.data import sentence_prediction_dataloader\n",
"from official.nlp import optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PsbhUV_p3wxN"
},
"source": [
"### Configure the distribution strategy\n",
"\n",
"While `tf.distribute` won't help the model's runtime if you're running on a single machine or GPU, it's necessary for TPUs. Setting up a distribution strategy allows you to use the same code regardless of the configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PG702dqstXIk"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
" tf.config.experimental_connect_to_cluster(resolver)\n",
" tf.tpu.experimental.initialize_tpu_system(resolver)\n",
" strategy = tf.distribute.TPUStrategy(resolver)\n",
"else:\n",
" strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eaQgM98deAMu"
},
"source": [
"For more information about the TPU setup, refer to the [TPU guide](https://www.tensorflow.org/guide/tpu)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7aOxMLLV32Zm"
},
"source": [
"### Create a model and an optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YRdWzOfK3_56"
},
"outputs": [],
"source": [
"max_seq_length = 128\n",
"learning_rate = 3e-5\n",
"num_train_epochs = 3\n",
"train_batch_size = 32\n",
"eval_batch_size = 64\n",
"\n",
"train_data_size = 3668\n",
"steps_per_epoch = int(train_data_size / train_batch_size)\n",
"\n",
"train_steps = steps_per_epoch * num_train_epochs\n",
"warmup_steps = int(train_steps * 0.1)\n",
"\n",
"print(\"train batch size: \", train_batch_size)\n",
"print(\"train epochs: \", num_train_epochs)\n",
"print(\"steps_per_epoch: \", steps_per_epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BVw3886Ysse6"
},
"outputs": [],
"source": [
"model_dir = pathlib.Path(tempfile.mkdtemp())\n",
"print(model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mu9cV7ew-cVe"
},
"source": [
"\n",
"Create a BERT Classifier model and a simple optimizer. They must be created inside `strategy.scope` so that the variables can be distributed. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gmwtX0cp-mj5"
},
"outputs": [],
"source": [
"with strategy.scope():\n",
" encoder_network = tfm.nlp.encoders.build_encoder(\n",
" tfm.nlp.encoders.EncoderConfig(type=\"bert\"))\n",
" classifier_model = tfm.nlp.models.BertClassifier(\n",
" network=encoder_network, num_classes=2)\n",
"\n",
" optimizer = optimization.create_optimizer(\n",
" init_lr=3e-5,\n",
" num_train_steps=steps_per_epoch * num_train_epochs,\n",
" num_warmup_steps=warmup_steps,\n",
" end_lr=0.0,\n",
" optimizer_type='adamw')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jwJSfewG5jVV"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQy5pYgAf8Ft"
},
"source": [
"### Initialize from a Checkpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6CE14GEybgRR"
},
"outputs": [],
"source": [
"bert_dir = 'gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12/'\n",
"tf.io.gfile.listdir(bert_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x7fwxz9xidKt"
},
"outputs": [],
"source": [
"bert_checkpoint = bert_dir + 'bert_model.ckpt'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q7EfwVCRe7N_"
},
"outputs": [],
"source": [
"def init_from_ckpt_fn():\n",
" init_checkpoint = tf.train.Checkpoint(**classifier_model.checkpoint_items)\n",
" with strategy.scope():\n",
" (init_checkpoint\n",
" .read(bert_checkpoint)\n",
" .expect_partial()\n",
" .assert_existing_objects_matched())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M0LUMlsde-2f"
},
"outputs": [],
"source": [
"with strategy.scope():\n",
" init_from_ckpt_fn()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gAuns4vN_IYV"
},
"source": [
"\n",
"To use Orbit, create a `tf.train.CheckpointManager` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i7NwM1Jq_MX7"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(model=classifier_model, optimizer=optimizer)\n",
"checkpoint_manager = tf.train.CheckpointManager(\n",
" checkpoint,\n",
" directory=model_dir,\n",
" max_to_keep=5,\n",
" step_counter=optimizer.iterations,\n",
" checkpoint_interval=steps_per_epoch,\n",
" init_fn=init_from_ckpt_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nzeiAFhcCOAo"
},
"source": [
"### Create distributed datasets\n",
"\n",
"As a shortcut for this tutorial, the [GLUE/MPRC dataset](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc) has been converted to a pair of [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) files containing serialized `tf.train.Example` protos.\n",
"\n",
"The data was converted using [this script](https://github.com/tensorflow/models/blob/r2.9.0/official/nlp/data/create_finetuning_data.py).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZVfbiT1dCnDk"
},
"outputs": [],
"source": [
"train_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_train.tf_record\"\n",
"eval_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_eval.tf_record\"\n",
"\n",
"def _dataset_fn(input_file_pattern, \n",
" global_batch_size, \n",
" is_training, \n",
" input_context=None):\n",
" data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(\n",
" input_path=input_file_pattern,\n",
" seq_length=max_seq_length,\n",
" global_batch_size=global_batch_size,\n",
" is_training=is_training)\n",
" return sentence_prediction_dataloader.SentencePredictionDataLoader(\n",
" data_config).load(input_context=input_context)\n",
"\n",
"train_dataset = orbit.utils.make_distributed_dataset(\n",
" strategy, _dataset_fn, input_file_pattern=train_data_path,\n",
" global_batch_size=train_batch_size, is_training=True)\n",
"eval_dataset = orbit.utils.make_distributed_dataset(\n",
" strategy, _dataset_fn, input_file_pattern=eval_data_path,\n",
" global_batch_size=eval_batch_size, is_training=False)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dPgiDBQCjsXW"
},
"source": [
"### Create a loss function\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7MCUmmo2jvXl"
},
"outputs": [],
"source": [
"def loss_fn(labels, logits):\n",
" \"\"\"Classification loss.\"\"\"\n",
" labels = tf.squeeze(labels)\n",
" log_probs = tf.nn.log_softmax(logits, axis=-1)\n",
" one_hot_labels = tf.one_hot(\n",
" tf.cast(labels, dtype=tf.int32), depth=2, dtype=tf.float32)\n",
" per_example_loss = -tf.reduce_sum(\n",
" tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)\n",
" return tf.reduce_mean(per_example_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ohlO-8FQkwsr"
},
"source": [
" \u003c/devsite-expandable\u003e\u003c/div\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymhbvPaEJ96T"
},
"source": [
"## Controllers, Trainers and Evaluators\n",
"\n",
"When using Orbit, the `orbit.Controller` class drives the training. The Controller handles the details of distribution strategies, step counting, TensorBoard summaries, and checkpointing.\n",
"\n",
"To implement the training and evaluation, pass a `trainer` and `evaluator`, which are subclass instances of `orbit.AbstractTrainer` and `orbit.AbstractEvaluator`. Keeping with Orbit's light-weight design, these two classes have a minimal interface.\n",
"\n",
"The Controller drives training and evaluation by calling `trainer.train(num_steps)` and `evaluator.evaluate(num_steps)`. These `train` and `evaluate` methods return a dictionary of results for logging.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a6sU2vBeyXtu"
},
"source": [
"Training is broken into chunks of length `num_steps`. This is set by the Controller's [`steps_per_loop`](https://tensorflow.org/api_docs/python/orbit/Controller#args) argument. With the trainer and evaluator abstract base classes, the meaning of `num_steps` is entirely determined by the implementer.\n",
"\n",
"Some common examples include:\n",
"\n",
"* Having the chunks represent dataset-epoch boundaries, like the default keras setup. \n",
"* Using it to more efficiently dispatch a number of training steps to an accelerator with a single `tf.function` call (like the `steps_per_execution` argument to `Model.compile`). \n",
"* Subdividing into smaller chunks as needed.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p4mXGIRJsf1j"
},
"source": [
"### StandardTrainer and StandardEvaluator\n",
"\n",
"Orbit provides two additional classes, `orbit.StandardTrainer` and `orbit.StandardEvaluator`, to give more structure around the training and evaluation loops.\n",
"\n",
"With StandardTrainer, you only need to set `train_loop_begin`, `train_step`, and `train_loop_end`. The base class handles the loops, dataset logic, and `tf.function` (according to the options set by their `orbit.StandardTrainerOptions`). This is simpler than `orbit.AbstractTrainer`, which requires you to handle the entire loop. StandardEvaluator has a similar structure and simplification to StandardTrainer.\n",
"\n",
"This is effectively an implementation of the `steps_per_execution` approach used by Keras."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-hvZ8PvohmR5"
},
"source": [
"Contrast this with Keras, where training is divided both into epochs (a single pass over the dataset) and `steps_per_execution`(set within [`Model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile). In Keras, metric averages are typically accumulated over an epoch, and reported \u0026 reset between epochs. For efficiency, `steps_per_execution` only controls the number of training steps made per call.\n",
"\n",
"In this simple case, `steps_per_loop` (within `StandardTrainer`) will handle both the metric resets and the number of steps per call. \n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NoDFN1L-1jIu"
},
"source": [
"The minimal setup when using these base classes is to implement the methods as follows:\n",
"\n",
"1. `StandardTrainer.train_loop_begin` - Reset your training metrics.\n",
"2. `StandardTrainer.train_step` - Apply a single gradient update.\n",
"3. `StandardTrainer.train_loop_end` - Report your training metrics.\n",
"\n",
"and\n",
"\n",
"4. `StandardEvaluator.eval_begin` - Reset your evaluation metrics.\n",
"5. `StandardEvaluator.eval_step` - Run a single evaluation setep.\n",
"6. `StandardEvaluator.eval_reduce` - This is not necessary in this simple setup.\n",
"7. `StandardEvaluator.eval_end` - Report your evaluation metrics.\n",
"\n",
"Depending on the settings, the base class may wrap the `train_step` and `eval_step` code in `tf.function` or `tf.while_loop`, which has some limitations compared to standard python."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3KPA0NDZt2JD"
},
"source": [
"### Define the trainer class"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6LDPsvJwfuPR"
},
"source": [
"In this section you'll create a subclass of `orbit.StandardTrainer` for this task. \n",
"\n",
"Note: To better explain the `BertClassifierTrainer` class, this section defines each method as a stand-alone function and assembles them into a class at the end.\n",
"\n",
"The trainer needs access to the training data, model, optimizer, and distribution strategy. Pass these as arguments to the initializer.\n",
"\n",
"Define a single training metric, `training_loss`, using `tf.keras.metrics.Mean`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6DQYZN5ax-MG"
},
"outputs": [],
"source": [
"def trainer_init(self,\n",
" train_dataset,\n",
" model,\n",
" optimizer,\n",
" strategy):\n",
" self.strategy = strategy\n",
" with self.strategy.scope():\n",
" self.model = model\n",
" self.optimizer = optimizer\n",
" self.global_step = self.optimizer.iterations\n",
" \n",
"\n",
" self.train_loss = tf.keras.metrics.Mean(\n",
" 'training_loss', dtype=tf.float32)\n",
" orbit.StandardTrainer.__init__(self, train_dataset)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QOwHD7U5hVue"
},
"source": [
"Before starting a run of the training loop, the `train_loop_begin` method will reset the `train_loss` metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AkpcHqXShWL0"
},
"outputs": [],
"source": [
"def train_loop_begin(self):\n",
" self.train_loss.reset_states()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UjtFOFyxn2BB"
},
"source": [
"The `train_step` is a straight-forward loss-calculation and gradient update that is run by the distribution strategy. This is accomplished by defining the gradient step as a nested function (`step_fn`).\n",
"\n",
"The method receives `tf.distribute.DistributedIterator` to handle the [distributed input](https://www.tensorflow.org/tutorials/distribute/input). The method uses `Strategy.run` to execute `step_fn` and feeds it from the distributed iterator.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QuPwNnT5I-GP"
},
"outputs": [],
"source": [
"def train_step(self, iterator):\n",
"\n",
" def step_fn(inputs):\n",
" labels = inputs.pop(\"label_ids\")\n",
" with tf.GradientTape() as tape:\n",
" model_outputs = self.model(inputs, training=True)\n",
" # Raw loss is used for reporting in metrics/logs.\n",
" raw_loss = loss_fn(labels, model_outputs)\n",
" # Scales down the loss for gradients to be invariant from replicas.\n",
" loss = raw_loss / self.strategy.num_replicas_in_sync\n",
"\n",
" grads = tape.gradient(loss, self.model.trainable_variables)\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
" # For reporting, the metric takes the mean of losses.\n",
" self.train_loss.update_state(raw_loss)\n",
"\n",
" self.strategy.run(step_fn, args=(next(iterator),))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VmQNwx5QpyDt"
},
"source": [
"The `orbit.StandardTrainer` handles the `@tf.function` and loops.\n",
"\n",
"After running through `num_steps` of training, `StandardTrainer` calls `train_loop_end`. The function returns the metric results:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GqCyVk1zzGod"
},
"outputs": [],
"source": [
"def train_loop_end(self):\n",
" return {\n",
" self.train_loss.name: self.train_loss.result(),\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xvmLONl80KUv"
},
"source": [
"Build a subclass of `orbit.StandardTrainer` with those methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oRoL7VE6xt1G"
},
"outputs": [],
"source": [
"class BertClassifierTrainer(orbit.StandardTrainer):\n",
" __init__ = trainer_init\n",
" train_loop_begin = train_loop_begin\n",
" train_step = train_step\n",
" train_loop_end = train_loop_end"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yjG4QAWj1B00"
},
"source": [
"### Define the evaluator class\n",
"\n",
"Note: Like the previous section, this section defines each method as a stand-alone function and assembles them into a `BertClassifierEvaluator` class at the end.\n",
"\n",
"The evaluator is even simpler for this task. It needs access to the evaluation dataset, the model, and the strategy. After saving references to those objects, the constructor just needs to create the metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvX7seCY1CWj"
},
"outputs": [],
"source": [
"def evaluator_init(self,\n",
" eval_dataset,\n",
" model,\n",
" strategy):\n",
" self.strategy = strategy\n",
" with self.strategy.scope():\n",
" self.model = model\n",
" \n",
" self.eval_loss = tf.keras.metrics.Mean(\n",
" 'evaluation_loss', dtype=tf.float32)\n",
" self.eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
" name='accuracy', dtype=tf.float32)\n",
" orbit.StandardEvaluator.__init__(self, eval_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0r-z-XK7ybyX"
},
"source": [
"Similar to the trainer, the `eval_begin` and `eval_end` methods just need to reset the metrics before the loop and then report the results after the loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7VVb0Tg6yZjI"
},
"outputs": [],
"source": [
"def eval_begin(self):\n",
" self.eval_accuracy.reset_states()\n",
" self.eval_loss.reset_states()\n",
"\n",
"def eval_end(self):\n",
" return {\n",
" self.eval_accuracy.name: self.eval_accuracy.result(),\n",
" self.eval_loss.name: self.eval_loss.result(),\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iDOZcQvttdmZ"
},
"source": [
"The `eval_step` method works like `train_step`. The inner `step_fn` defines the actual work of calculating the loss \u0026 accuracy and updating the metrics. The outer `eval_step` receives `tf.distribute.DistributedIterator` as input, and uses `Strategy.run` to launch the distributed execution to `step_fn`, feeding it from the distributed iterator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JLJnYuuGJjvd"
},
"outputs": [],
"source": [
"def eval_step(self, iterator):\n",
"\n",
" def step_fn(inputs):\n",
" labels = inputs.pop(\"label_ids\")\n",
" model_outputs = self.model(inputs, training=True)\n",
" loss = loss_fn(labels, model_outputs)\n",
" self.eval_loss.update_state(loss)\n",
" self.eval_accuracy.update_state(labels, model_outputs)\n",
"\n",
" self.strategy.run(step_fn, args=(next(iterator),))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gt3hh0V30QcP"
},
"source": [
"Build a subclass of `orbit.StandardEvaluator` with those methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3zqyLxfNyCgA"
},
"outputs": [],
"source": [
"class BertClassifierEvaluator(orbit.StandardEvaluator):\n",
" __init__ = evaluator_init\n",
" eval_begin = eval_begin\n",
" eval_end = eval_end\n",
" eval_step = eval_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aK9gEja9qPOc"
},
"source": [
"### End-to-end training and evaluation\n",
"\n",
"To run the training and evaluation, simply create the trainer, evaluator, and `orbit.Controller` instances. Then call the `Controller.train_and_evaluate` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PqQetxyXqRA9"
},
"outputs": [],
"source": [
"trainer = BertClassifierTrainer(\n",
" train_dataset, classifier_model, optimizer, strategy)\n",
"\n",
"evaluator = BertClassifierEvaluator(\n",
" eval_dataset, classifier_model, strategy)\n",
"\n",
"controller = orbit.Controller(\n",
" trainer=trainer,\n",
" evaluator=evaluator,\n",
" global_step=trainer.global_step,\n",
" steps_per_loop=20,\n",
" checkpoint_manager=checkpoint_manager)\n",
"\n",
"result = controller.train_and_evaluate(\n",
" train_steps=steps_per_epoch * num_train_epochs,\n",
" eval_steps=-1,\n",
" eval_interval=steps_per_epoch)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"Tce3stUlHN0L"
],
"name": "Orbit Tutorial.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFdPvlXBOdUN"
},
"source": [
"# Image classification with Model Garden"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfBg1C5NB3X0"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ta_nFXaVAqLD"
},
"source": [
"This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
"\n",
"Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
"\n",
"This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. This tutorial uses the ResNet-18 model, a convolutional neural network with 18 layers.\n",
"\n",
"This tutorial demonstrates how to:\n",
"1. Use models from the TensorFlow Models package.\n",
"2. Fine-tune a pre-built ResNet for image classification.\n",
"3. Export the tuned ResNet model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G2FlaQcEPOER"
},
"source": [
"## Setup\n",
"\n",
"Install and import the necessary modules. This tutorial uses the `tf-models-nightly` version of Model Garden.\n",
"\n",
"Note: Upgrading TensorFlow to 2.9 in Colab breaks GPU support, so this colab is set to run on CPU until the Colab runtimes are updated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XvWfdCrvrV5W"
},
"outputs": [],
"source": [
"!pip uninstall -y opencv-python\n",
"!pip install -U -q \"tensorflow\u003e=2.9.0\" \"tf-models-official\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CKYMTPjOE400"
},
"source": [
"Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wlon1uoIowmZ"
},
"outputs": [],
"source": [
"import pprint\n",
"import tempfile\n",
"\n",
"from IPython import display\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AVTs0jDd1b24"
},
"source": [
"The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NHT1iiIiBzlC"
},
"outputs": [],
"source": [
"import tensorflow_models as tfm\n",
"\n",
"# These are not in the tfm public API for v2.9. They will be available in v2.10\n",
"from official.vision.serving import export_saved_model_lib\n",
"import official.core.train_lib"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aKv3wdqkQ8FU"
},
"source": [
"## Configure the ResNet-18 model for the Cifar-10 dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5iN8mHEJjKYE"
},
"source": [
"The CIFAR10 dataset contains 60,000 color images in mutually exclusive 10 classes, with 6,000 images in each class.\n",
"\n",
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
"\n",
"Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1M77f88Dj2Td"
},
"outputs": [],
"source": [
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
"tfds_name = 'cifar10'\n",
"ds_info = tfds.builder(tfds_name ).info\n",
"ds_info"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U6PVwXA-j3E7"
},
"source": [
"Adjust the model and dataset configurations so that it works with Cifar-10 (`cifar10`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YWI7faVStQaV"
},
"outputs": [],
"source": [
"# Configure model\n",
"exp_config.task.model.num_classes = 10\n",
"exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
"exp_config.task.model.backbone.resnet.model_id = 18\n",
"\n",
"# Configure training and testing data\n",
"batch_size = 128\n",
"\n",
"exp_config.task.train_data.input_path = ''\n",
"exp_config.task.train_data.tfds_name = tfds_name\n",
"exp_config.task.train_data.tfds_split = 'train'\n",
"exp_config.task.train_data.global_batch_size = batch_size\n",
"\n",
"exp_config.task.validation_data.input_path = ''\n",
"exp_config.task.validation_data.tfds_name = tfds_name\n",
"exp_config.task.validation_data.tfds_split = 'test'\n",
"exp_config.task.validation_data.global_batch_size = batch_size\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DE3ggKzzTD56"
},
"source": [
"Adjust the trainer configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "inE_-4UGkLud"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'GPU'\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'TPU'\n",
"else:\n",
" print('Running on CPU is slow, so only train for a few steps.')\n",
" device = 'CPU'\n",
"\n",
"if device=='CPU':\n",
" train_steps = 20\n",
" exp_config.trainer.steps_per_loop = 5\n",
"else:\n",
" train_steps=5000\n",
" exp_config.trainer.steps_per_loop = 100\n",
"\n",
"exp_config.trainer.summary_interval = 100\n",
"exp_config.trainer.checkpoint_interval = train_steps\n",
"exp_config.trainer.validation_interval = 1000\n",
"exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size\n",
"exp_config.trainer.train_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5mTcDnBiTOYD"
},
"source": [
"Print the modified configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tuVfxSBCTK-y"
},
"outputs": [],
"source": [
"pprint.pprint(exp_config.as_dict())\n",
"\n",
"display.Javascript(\"google.colab.output.setIframeHeight('300px');\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w7_X0UHaRF2m"
},
"source": [
"Set up the distribution strategy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ykL14FIbTaSt"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" tf.tpu.experimental.initialize_tpu_system()\n",
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"else:\n",
" print('Warning: this will be really slow.')\n",
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W4k5YH5pTjaK"
},
"source": [
"Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
"\n",
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6MgYSH0PtUaW"
},
"outputs": [],
"source": [
"with distribution_strategy.scope():\n",
" model_dir = tempfile.mkdtemp()\n",
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n",
"\n",
"tf.keras.utils.plot_model(task.build_model(), show_shapes=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFXEZYdzBKoX"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" print()\n",
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
" print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yrwxnGDaRU0U"
},
"source": [
"## Visualize the training data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "683c255c6c52"
},
"source": [
"The dataloader applies a z-score normalization using \n",
"`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PdmOz2EC0Nx2"
},
"outputs": [],
"source": [
"plt.hist(images.numpy().flatten());"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7a8582ebde7b"
},
"source": [
"Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to lookup the text descriptions of each class ID."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wq4Wq_CuDG3Q"
},
"outputs": [],
"source": [
"label_info = ds_info.features['label']\n",
"label_info.int2str(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8c652a6fdbcf"
},
"source": [
"Visualize a batch of the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZKfTxytf1l0d"
},
"outputs": [],
"source": [
"def show_batch(images, labels, predictions=None):\n",
" plt.figure(figsize=(10, 10))\n",
" min = images.numpy().min()\n",
" max = images.numpy().max()\n",
" delta = max - min\n",
"\n",
" for i in range(12):\n",
" plt.subplot(6, 6, i + 1)\n",
" plt.imshow((images[i]-min) / delta)\n",
" if predictions is None:\n",
" plt.title(label_info.int2str(labels[i]))\n",
" else:\n",
" if labels[i] == predictions[i]:\n",
" color = 'g'\n",
" else:\n",
" color = 'r'\n",
" plt.title(label_info.int2str(predictions[i]), color=color)\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xkA5h_RBtYYU"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" show_batch(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v_A9VnL2RbXP"
},
"source": [
"## Visualize the testing data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AXovuumW_I2z"
},
"source": [
"Visualize a batch of images from the validation dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ma-_Eb-nte9A"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10));\n",
"for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n",
" show_batch(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ihKJt2FHRi2N"
},
"source": [
"## Train and evaluate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0AFMNvYxtjXx"
},
"outputs": [],
"source": [
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
" distribution_strategy=distribution_strategy,\n",
" task=task,\n",
" mode='train_and_eval',\n",
" params=exp_config,\n",
" model_dir=model_dir,\n",
" run_post_eval=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCcHMQYhozmA"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(model, show_shapes=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7nVfxlBA8Gb"
},
"source": [
"Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0124f938a1b9"
},
"outputs": [],
"source": [
"for key, value in eval_logs.items():\n",
" print(f'{key:20}: {value.numpy():.3f}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDys5bZ1zsml"
},
"source": [
"Run a batch of the processed training data through the model, and view the results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GhI7zR-Uz1JT"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" predictions = model.predict(images)\n",
" predictions = tf.argmax(predictions, axis=-1)\n",
"\n",
"show_batch(images, labels, tf.cast(predictions, tf.int32))\n",
"\n",
"if device=='CPU':\n",
" plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fkE9locGTBgt"
},
"source": [
"## Export a SavedModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9669d08c91af"
},
"source": [
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AQCFa7BvtmDg"
},
"outputs": [],
"source": [
"# Saving and exporting the trained model\n",
"export_saved_model_lib.export_inference_graph(\n",
" input_type='image_tensor',\n",
" batch_size=1,\n",
" input_image_size=[32, 32],\n",
" params=exp_config,\n",
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
" export_dir='./export/')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vVr6DxNqTyLZ"
},
"source": [
"Test the exported model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gP7nOvrftsB0"
},
"outputs": [],
"source": [
"# Importing SavedModel\n",
"imported = tf.saved_model.load('./export/')\n",
"model_fn = imported.signatures['serving_default']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GiOp2WVIUNUZ"
},
"source": [
"Visualize the predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BTRMrZQAN4mk"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for data in tfds.load('cifar10', split='test').batch(12).take(1):\n",
" predictions = []\n",
" for image in data['image']:\n",
" index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
" predictions.append(index)\n",
" show_batch(data['image'], data['label'], predictions)\n",
"\n",
" if device=='CPU':\n",
" plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "classification_with_model_garden.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
# Moved
These files have moved to:
https://github.com/tensorflow/models/blob/master/docs
@@ -20,7 +20,7 @@ examples.
 * [`losses`](losses) contains common loss computation used in NLP tasks.
 Please see the colab
-[nlp_modeling_library_intro.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb)
+[NLP modeling library intro.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/docs/nlp/index.ipynb)
 for how to build transformer-based NLP models using the above primitives.
 Besides the pre-defined primitives, it also provides scaffold classes to allow
@@ -43,7 +43,7 @@ custom hidden layer (which will replace the Transformer instantiation in the
 encoder).
 Please see the colab
-[customize_encoder.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb)
+[customize_encoder.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb)
 for how to use scaffold classes to build novel architectures.
 BERT and ALBERT models in this repo are implemented using this library.