Commit 47bc1813 authored by syiming's avatar syiming

Merge remote-tracking branch 'upstream/master' into add_multilevel_crop_and_resize

parents d8611151 b035a227
> :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
>
> * Template version: 1.0.2020.170
> * Please modify sections depending on needs.

# Model name, Paper title, or Project Name

> :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)

[![Paper](http://img.shields.io/badge/Paper-arXiv.YYMM.NNNNN-B3181B?logo=arXiv)](https://arxiv.org/abs/...)

This repository is the official or unofficial implementation of the following paper.
...@@ -28,8 +28,8 @@ This repository is the official or unofficial implementation of the following pa

> :memo: Provide maintainer information.

* Full name ([@GitHub username](https://github.com/username))
* Full name ([@GitHub username](https://github.com/username))

## Table of Contents

...@@ -37,8 +37,8 @@ This repository is the official or unofficial implementation of the following pa

## Requirements

[![TensorFlow 2.1](https://img.shields.io/badge/TensorFlow-2.1-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
[![Python 3.6](https://img.shields.io/badge/Python-3.6-3776AB)](https://www.python.org/downloads/release/python-360/)

> :memo: Provide details of the software required.
>
...@@ -54,6 +54,8 @@ pip install -r requirements.txt

## Results

[![TensorFlow Hub](https://img.shields.io/badge/TF%20Hub-Models-FF6F00?logo=tensorflow)](https://tfhub.dev/...)

> :memo: Provide a table with results. (e.g., accuracy, latency)
>
> * Provide links to the pre-trained models (checkpoint, SavedModel files).

...@@ -104,6 +106,8 @@ python3 ...

## License

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

> :memo: Place your license text in a file named LICENSE in the root of the repository.
>
> * Include information about your license.

...
...@@ -2,7 +2,8 @@

# Welcome to the Model Garden for TensorFlow

The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users can take full advantage of TensorFlow for their research and product development.

| Directory | Description |
|-----------|-------------|

...@@ -10,20 +11,28 @@ The TensorFlow Model Garden is a repository with a number of different implement

| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |

## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)

| Date | News |
|------|------|
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released |
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection |
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
## [Milestones](https://github.com/tensorflow/models/milestones)
| Date | Milestone |
|------|-----------|
| July 8, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |
## Contributions

[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)

If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).

## License

...
...@@ -6,13 +6,12 @@ This repository provides a curated list of the GitHub repositories with machine

**Note**: Contributing companies or individuals are responsible for maintaining their repositories.

## Computer Vision

### Image Recognition

| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [DenseNet 169](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/densenet169) | [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V3](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv3) | [Rethinking the Inception Architecture<br/>for Computer Vision](https://arxiv.org/pdf/1512.00567.pdf) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V4](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv4) | [Inception-v4, Inception-ResNet and the Impact<br/>of Residual Connections on Learning](https://arxiv.org/pdf/1602.07261) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |

...@@ -21,12 +20,13 @@ This repository provides a curated list of the GitHub repositories with machine

| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Segmentation

| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/MaskRCNN) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
| [U-Net Medical Image Segmentation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/UNet_Medical) | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |

## Contributions

If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
...@@ -19,9 +19,10 @@ In the near future, we will add:

* State-of-the-art language understanding models:
  More members in Transformer family
* State-of-the-art image classification models:
  EfficientNet, MnasNet, and variants
* State-of-the-art object detection and instance segmentation models:
  RetinaNet, Mask R-CNN, SpineNet, and variants

## Table of Contents
...@@ -43,6 +44,7 @@ In the near future, we will add:

|-------|-------------------|
| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
| [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |

#### Object Detection and Segmentation

...@@ -50,6 +52,8 @@ In the near future, we will add:

|-------|-------------------|
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |

### Natural Language Processing

...
...@@ -271,6 +271,23 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):

    FLAGS.strategy_type = 'tpu'
    self._run_and_report_benchmark(params, do_eval=False, warmup=0)

  @flagsaver.flagsaver
  def benchmark_2x2_tpu_spinenet_coco(self):
    """Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
    self._setup()
    params = self._params()
    params['architecture']['backbone'] = 'spinenet'
    params['architecture']['multilevel_features'] = 'identity'
    params['architecture']['use_bfloat16'] = False
    params['train']['batch_size'] = 64
    params['train']['total_steps'] = 1875  # One epoch.
    params['train']['iterations_per_loop'] = 500
    params['train']['checkpoint']['path'] = ''
    FLAGS.model_dir = self._get_model_dir(
        'real_benchmark_2x2_tpu_spinenet_coco')
    FLAGS.strategy_type = 'tpu'
    self._run_and_report_benchmark(params, do_eval=False, warmup=0)


if __name__ == '__main__':
  tf.test.main()
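The benchmark added above works by overriding individual entries in a nested `params` dict (backbone, batch size, step counts) before launching the run. As a rough, stdlib-only sketch of that override pattern — the helper name and the example dicts are hypothetical, not part of the benchmark code:

```python
import copy

def override_params(base, overrides):
    """Return a deep copy of `base` with nested `overrides` applied."""
    params = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(params.get(key), dict):
            # Recurse so sibling keys at the same level are preserved.
            params[key] = override_params(params[key], value)
        else:
            params[key] = value
    return params

base = {
    'architecture': {'backbone': 'resnet', 'use_bfloat16': True},
    'train': {'batch_size': 256, 'total_steps': 22500},
}
overrides = {
    'architecture': {'backbone': 'spinenet'},
    'train': {'batch_size': 64, 'total_steps': 1875},
}
params = override_params(base, overrides)
# Untouched sibling keys such as 'use_bfloat16' survive the override.
```

The deep copy matters: mutating a shared default-params dict in place would leak one benchmark's settings into the next.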
...@@ -32,9 +32,9 @@ from official.vision.segmentation import unet_model as unet_model_lib

UNET3D_MIN_ACCURACY = 0.90
UNET3D_MAX_ACCURACY = 0.98
UNET_TRAINING_FILES = 'gs://mlcompass-data/unet3d/train_data/*'
UNET_EVAL_FILES = 'gs://mlcompass-data/unet3d/eval_data/*'
UNET_MODEL_CONFIG_FILE = 'gs://mlcompass-data/unet3d/config/unet_config.yaml'

FLAGS = flags.FLAGS

...
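The constants in this hunk now hold GCS glob patterns rather than bare file names; the training code presumably resolves them with TensorFlow's file APIs (e.g. `tf.io.gfile.glob`). As a rough illustration of how such a pattern selects data shards, here is a stdlib-only sketch — the shard names are hypothetical:

```python
import fnmatch

# Pattern taken from the constants above.
UNET_TRAINING_FILES = 'gs://mlcompass-data/unet3d/train_data/*'

# Hypothetical listing of what might exist under the bucket.
available = [
    'gs://mlcompass-data/unet3d/train_data/shard-00000',
    'gs://mlcompass-data/unet3d/train_data/shard-00001',
    'gs://mlcompass-data/unet3d/eval_data/shard-00000',
]

# The wildcard selects only the shards under train_data/.
train_files = fnmatch.filter(available, UNET_TRAINING_FILES)
```

Pointing the constants at shared GCS data means the benchmark no longer depends on files being staged on the local machine.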
...@@ -4,64 +4,79 @@

"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vXLA5InzXydn"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "RuRlpLL-X0R_"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1mLJmVotXs64"
},
"source": [
"# Fine-tuning a BERT model"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hYEwGTeCXnnX"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
"  \u003ctd\u003e\n",
"    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/tutorials/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
"  \u003c/td\u003e\n",
"  \u003ctd\u003e\n",
"    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
"  \u003c/td\u003e\n",
"  \u003ctd\u003e\n",
"    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
"  \u003c/td\u003e\n",
"  \u003ctd\u003e\n",
"    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/fine_tuning_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
"  \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YN2ACivEPxgD"
},
"source": [
"In this example, we will work through fine-tuning a BERT model using the tensorflow-models PIP package.\n",
"\n",
"The pretrained BERT model this tutorial is based on is also available on [TensorFlow Hub](https://tensorflow.org/hub); to see how to use it, refer to the [Hub Appendix](#hub_bert)."
]
},
{
...@@ -71,7 +86,7 @@

"id": "s2d9S2CSSO1z"
},
"source": [
"## Setup"
]
},
{

...@@ -83,7 +98,7 @@

"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
]
},

...@@ -97,7 +112,8 @@

},
"outputs": [],
"source": [
"!pip install -q tf-nightly\n",
"!pip install -q tf-models-nightly"
]
},
{

...@@ -107,7 +123,7 @@

"id": "U-7qPCjWUAyy"
},
"source": [
"### Imports"
]
},
{
...@@ -123,67 +139,176 @@

"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"\n",
"import tensorflow_hub as hub\n",
"import tensorflow_datasets as tfds\n",
"tfds.disable_progress_bar()\n",
"\n",
"from official.modeling import tf_utils\n",
"from official import nlp\n",
"from official.nlp import bert\n",
"\n",
"# Load the required submodules\n",
"import official.nlp.optimization\n",
"import official.nlp.bert.bert_models\n",
"import official.nlp.bert.configs\n",
"import official.nlp.bert.run_classifier\n",
"import official.nlp.bert.tokenization\n",
"import official.nlp.data.classifier_data_lib\n",
"import official.nlp.modeling.losses\n",
"import official.nlp.modeling.models\n",
"import official.nlp.modeling.networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mbanlzTvJBsz"
},
"source": [
"### Resources"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PpW0x8TpR8DT"
},
"source": [
"This directory contains the configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "vzRHOLciR8eq"
},
"outputs": [],
"source": [
"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12\"\n",
"tf.io.gfile.listdir(gs_folder_bert)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9uFskufsR2LT"
},
"source": [
"You can get a pre-trained BERT encoder from TensorFlow Hub here:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "e0dAkUttJAzj"
},
"outputs": [],
"source": [
"hub_url_bert = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\""
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qv6abtRvH4xO"
},
"source": [
"## The data\n",
"For this example we used the [GLUE MRPC dataset from TFDS](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).\n",
"\n",
"This dataset is not set up so that it can be directly fed into the BERT model, so this section also handles the necessary preprocessing."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "28DvUhC1YUiB"
},
"source": [
"### Get the dataset from TensorFlow Datasets\n",
"\n",
"The Microsoft Research Paraphrase Corpus (Dolan \u0026 Brockett, 2005) is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent.\n",
"\n",
"* Number of labels: 2.\n",
"* Size of training dataset: 3668.\n",
"* Size of evaluation dataset: 408.\n",
"* Maximum sequence length of training and evaluation dataset: 128.\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ijikx5OsH9AT"
},
"outputs": [],
"source": [
"glue, info = tfds.load('glue/mrpc', with_info=True,\n",
"                       # It's small, load the whole dataset\n",
"                       batch_size=-1)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xf9zz4vLYXjr"
},
"outputs": [],
"source": [
"list(glue.keys())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ZgBg2r2nYT-K"
},
"source": [
"The `info` object describes the dataset and its features:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "IQrHxv7W7jH5"
},
"outputs": [],
"source": [
"info.features"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vhsVWYNxazz5"
},
"source": [
"The two classes are:"
]
},
{
...@@ -192,43 +317,21 @@

"metadata": {
"colab": {},
"colab_type": "code",
"id": "n0gfc_VTayfQ"
},
"outputs": [],
"source": [
"info.features['label'].names"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "38zJcap6xkbC"
},
"source": [
"Here is one example from the training set:"
]
},
{
...@@ -237,82 +340,38 @@

"metadata": {
"colab": {},
"colab_type": "code",
"id": "xON_i6SkwApW"
},
"outputs": [],
"source": [
"glue_train = glue['train']\n",
"\n",
"for key, value in glue_train.items():\n",
"  print(f\"{key:9s}: {value[0].numpy()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9fbTyfJpNr7x"
},
"source": [
"### The BERT tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wqeN54S61ZKQ"
},
"source": [
"To fine tune a pre-trained model you need to be sure that you're using exactly the same tokenization, vocabulary, and index mapping as you used during training.\n",
"\n",
"Here, a Bert Model is constructed from the json file with parameters. The bert_config defines the core Bert Model, which is a Keras model to predict the outputs of *num_classes* from the inputs with maximum sequence length *max_seq_length*. " "The BERT tokenizer used in this tutorial is written in pure Python (It's not built out of TensorFlow ops). So you can't just plug it into your model as a `keras.layer` like you can with `preprocessing.TextVectorization`.\n",
"\n",
"The following code rebuilds the tokenizer that was used by the base model:"
] ]
}, },
{ {
...@@ -321,44 +380,26 @@ ...@@ -321,44 +380,26 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "Qgajw8WPYzJZ" "id": "idxyhmrCQcw5"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n", "# Set up tokenizer to generate Tensorflow dataset\n",
"bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)\n", "tokenizer = bert.tokenization.FullTokenizer(\n",
"\n", " vocab_file=os.path.join(gs_folder_bert, \"vocab.txt\"),\n",
"bert_encoder = networks.TransformerEncoder(vocab_size=bert_config.vocab_size,\n", " do_lower_case=True)\n",
" hidden_size=bert_config.hidden_size,\n", "\n",
" num_layers=bert_config.num_hidden_layers,\n", "print(\"Vocab size:\", len(tokenizer.vocab))"
" num_attention_heads=bert_config.num_attention_heads,\n",
" intermediate_size=bert_config.intermediate_size,\n",
" activation=tf_utils.get_activation(bert_config.hidden_act),\n",
" dropout_rate=bert_config.hidden_dropout_prob,\n",
" attention_dropout_rate=bert_config.attention_probs_dropout_prob,\n",
" sequence_length=input_meta_data['max_seq_length'],\n",
" max_sequence_length=bert_config.max_position_embeddings,\n",
" type_vocab_size=bert_config.type_vocab_size,\n",
" embedding_width=bert_config.embedding_size,\n",
" initializer=tf.keras.initializers.TruncatedNormal(\n",
" stddev=bert_config.initializer_range))\n",
"\n",
"classifier_model = models.BertClassifier(\n",
" bert_encoder,\n",
" num_classes=input_meta_data['num_labels'],\n",
" dropout_rate=bert_config.hidden_dropout_prob,\n",
" initializer=tf.keras.initializers.TruncatedNormal(\n",
" stddev=bert_config.initializer_range))"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "colab_type": "text",
"id": "pkSq1wbNXBaa" "id": "zYHDSquU2lDU"
}, },
"source": [ "source": [
"### Initialize the encoder from a pretrained model" "Tokenize a sentence:"
] ]
}, },
{ {
...@@ -367,26 +408,40 @@ ...@@ -367,26 +408,40 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "X6N9NEqfXJCx" "id": "L_OfOYPg853R"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"checkpoint = tf.train.Checkpoint(model=bert_encoder)\n", "tokens = tokenizer.tokenize(\"Hello TensorFlow!\")\n",
"checkpoint.restore(\n", "print(tokens)\n",
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()" "ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"print(ids)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "colab_type": "text",
"id": "115caFLMk-_l" "id": "kkAXLtuyWWDI"
}, },
"source": [ "source": [
"### Set up an optimizer for the model\n", "### Preprocess the data\n",
"\n", "\n",
"BERT model adopts the Adam optimizer with weight decay.\n", "The section manually preprocessed the dataset into the format expected by the model.\n",
"It also employs a learning rate schedule that firstly warms up from 0 and then decays to 0." "\n",
"This dataset is small, so preprocessing can be done quickly and easily in memory. For larger datasets the `tf_models` library includes some tools for preprocessing and re-serializing a dataset. See [Appendix: Re-encoding a large dataset](#re_encoding_tools) for details."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "62UTWLQd9-LB"
},
"source": [
"#### Encode the sentences\n",
"\n",
"The model expects its two inputs sentences to be concatenated together. This input is expected to start with a `[CLS]` \"This is a classification problem\" token, and each sentence should end with a `[SEP]` \"Separator\" token:"
] ]
}, },
{ {
...@@ -395,45 +450,21 @@ ...@@ -395,45 +450,21 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "2Hf2rpRXk89N" "id": "bdL-dRNRBRJT"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Set up epochs and steps\n", "tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])"
"epochs = 3\n",
"train_data_size = input_meta_data['train_data_size']\n",
"steps_per_epoch = int(train_data_size / batch_size)\n",
"num_train_steps = steps_per_epoch * epochs\n",
"warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n",
"\n",
"# Create learning rate schedule that firstly warms up from 0 and they decy to 0.\n",
"lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
" initial_learning_rate=2e-5,\n",
" decay_steps=num_train_steps,\n",
" end_learning_rate=0)\n",
"lr_schedule = optimization.WarmUp(\n",
" initial_learning_rate=2e-5,\n",
" decay_schedule_fn=lr_schedule,\n",
" warmup_steps=warmup_steps)\n",
"optimizer = optimization.AdamWeightDecay(\n",
" learning_rate=lr_schedule,\n",
" weight_decay_rate=0.01,\n",
" beta_1=0.9,\n",
" beta_2=0.999,\n",
" epsilon=1e-6,\n",
" exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "colab_type": "text",
"id": "OTNcA0O0nSq9" "id": "UrPktnqpwqie"
}, },
"source": [ "source": [
"### Define metric_fn and loss_fn\n", "Start by encoding all the sentences while appending a `[SEP]` token, and packing them into ragged-tensors:"
"\n",
"The metric is accuracy and we use sparse categorical cross-entropy as loss."
] ]
}, },
{ {
...@@ -442,27 +473,43 @@ ...@@ -442,27 +473,43 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "ELHjRp87nVNH" "id": "BR7BmtU498Bh"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"def metric_fn():\n", "def encode_sentence(s):\n",
" return tf.keras.metrics.SparseCategoricalAccuracy(\n", " tokens = list(tokenizer.tokenize(s.numpy()))\n",
" 'accuracy', dtype=tf.float32)\n", " tokens.append('[SEP]')\n",
" return tokenizer.convert_tokens_to_ids(tokens)\n",
"\n", "\n",
"def classification_loss_fn(labels, logits):\n", "sentence1 = tf.ragged.constant([\n",
" return losses.weighted_sparse_categorical_crossentropy_loss(\n", " encode_sentence(s) for s in glue_train[\"sentence1\"]])\n",
" labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n" "sentence2 = tf.ragged.constant([\n",
" encode_sentence(s) for s in glue_train[\"sentence2\"]])"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "has42aUdfky-"
},
"outputs": [],
"source": [
"print(\"Sentence1 shape:\", sentence1.shape.as_list())\n",
"print(\"Sentence2 shape:\", sentence2.shape.as_list())"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "colab_type": "text",
"id": "78FEUOOEkoP0" "id": "MU9lTWy_xXbb"
}, },
"source": [ "source": [
"### Compile and train the model" "Now prepend a `[CLS]` token, and concatenate the ragged tensors to form a single `input_word_ids` tensor for each example. `RaggedTensor.to_tensor()` zero pads to the longest sequence."
] ]
}, },
{ {
...@@ -471,29 +518,46 @@ ...@@ -471,29 +518,46 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "nzi8hjeTQTRs" "id": "USD8uihw-g4J"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"classifier_model.compile(optimizer=optimizer,\n", "cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])]*sentence1.shape[0]\n",
" loss=classification_loss_fn,\n", "input_word_ids = tf.concat([cls, sentence1, sentence2], axis=-1)\n",
" metrics=[metric_fn()])\n", "_ = plt.pcolormesh(input_word_ids.to_tensor())"
"classifier_model.fit(\n",
" x=training_dataset,\n",
" validation_data=evaluation_dataset,\n",
" steps_per_epoch=steps_per_epoch,\n",
" epochs=epochs,\n",
" validation_steps=int(input_meta_data['eval_data_size'] / eval_batch_size))"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "colab_type": "text",
"id": "fVo_AnT0l26j" "id": "xmNv4l4k-dBZ"
},
"source": [
"#### Mask and input type"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DIWjNIKq-ldh"
},
"source": [
"The model expects two additional inputs:\n",
"\n",
"* The input mask\n",
"* The input type"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ulNZ4U96-8JZ"
}, },
"source": [ "source": [
"### Save the model" "The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding."
] ]
}, },
{ {
...@@ -502,21 +566,23 @@ ...@@ -502,21 +566,23 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "Nl5x6nElZqkP" "id": "EezOO9qj91kP"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"classifier_model.save('./saved_model', include_optimizer=False, save_format='tf')" "input_mask = tf.ones_like(input_word_ids).to_tensor()\n",
"\n",
"plt.pcolormesh(input_mask)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text", "colab_type": "text",
"id": "nWsE6yeyfW00" "id": "rxLenwAvCkBf"
}, },
"source": [ "source": [
"## Use the trained model to predict\n" "The \"input type\" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of. "
] ]
}, },
{ {
...@@ -525,13 +591,1223 @@ ...@@ -525,13 +591,1223 @@
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "vz7YJY2QYAjP" "id": "2CetH_5C9P2m"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"eval_predictions = classifier_model.predict(evaluation_dataset)\n", "type_cls = tf.zeros_like(cls)\n",
"for prediction in eval_predictions:\n", "type_s1 = tf.zeros_like(sentence1)\n",
" print(\"Predicted label id: %s\" % np.argmax(prediction))" "type_s2 = tf.ones_like(sentence2)\n",
"input_type_ids = tf.concat([type_cls, type_s1, type_s2], axis=-1).to_tensor()\n",
"\n",
"plt.pcolormesh(input_type_ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "P5UBnCn8Ii6s"
},
"source": [
"#### Put it all together\n",
"\n",
"Collect the above text parsing code into a single function, and apply it to each split of the `glue/mrpc` dataset."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "sDGiWYPLEd5a"
},
"outputs": [],
"source": [
"def encode_sentence(s, tokenizer):\n",
" tokens = list(tokenizer.tokenize(s))\n",
" tokens.append('[SEP]')\n",
" return tokenizer.convert_tokens_to_ids(tokens)\n",
"\n",
"def bert_encode(glue_dict, tokenizer):\n",
" num_examples = len(glue_dict[\"sentence1\"])\n",
" \n",
" sentence1 = tf.ragged.constant([\n",
" encode_sentence(s, tokenizer)\n",
" for s in np.array(glue_dict[\"sentence1\"])])\n",
" sentence2 = tf.ragged.constant([\n",
" encode_sentence(s, tokenizer)\n",
" for s in np.array(glue_dict[\"sentence2\"])])\n",
"\n",
" cls = [tokenizer.convert_tokens_to_ids(['[CLS]'])]*sentence1.shape[0]\n",
" input_word_ids = tf.concat([cls, sentence1, sentence2], axis=-1)\n",
"\n",
" input_mask = tf.ones_like(input_word_ids).to_tensor()\n",
"\n",
" type_cls = tf.zeros_like(cls)\n",
" type_s1 = tf.zeros_like(sentence1)\n",
" type_s2 = tf.ones_like(sentence2)\n",
" input_type_ids = tf.concat(\n",
" [type_cls, type_s1, type_s2], axis=-1).to_tensor()\n",
"\n",
" inputs = {\n",
" 'input_word_ids': input_word_ids.to_tensor(),\n",
" 'input_mask': input_mask,\n",
" 'input_type_ids': input_type_ids}\n",
"\n",
" return inputs"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "yuLKxf6zHxw-"
},
"outputs": [],
"source": [
"glue_train = bert_encode(glue['train'], tokenizer)\n",
"glue_train_labels = glue['train']['label']\n",
"\n",
"glue_validation = bert_encode(glue['validation'], tokenizer)\n",
"glue_validation_labels = glue['validation']['label']\n",
"\n",
"glue_test = bert_encode(glue['test'], tokenizer)\n",
"glue_test_labels = glue['test']['label']"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7FC5aLVxKVKK"
},
"source": [
"Each subset of the data has been converted to a dictionary of features, and a set of labels. Each feature in the input dictionary has the same shape, and the number of labels should match:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jyjTdGpFhO_1"
},
"outputs": [],
"source": [
"for key, value in glue_train.items():\n",
" print(f'{key:15s} shape: {value.shape}')\n",
"\n",
"print(f'glue_train_labels shape: {glue_train_labels.shape}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "FSwymsbkbLDA"
},
"source": [
"## The model"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Efrj3Cn1kLAp"
},
"source": [
"### Build the model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xxpOY5r2Ayq6"
},
"source": [
"The first step is to download the configuration for the pre-trained model.\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ujapVfZ_AKW7"
},
"outputs": [],
"source": [
"import json\n",
"\n",
"bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n",
"config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
"\n",
"bert_config = bert.configs.BertConfig.from_dict(config_dict)\n",
"\n",
"config_dict"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "96ldxDSwkVkj"
},
"source": [
"The `config` defines the core BERT Model, which is a Keras model to predict the outputs of `num_classes` from the inputs with maximum sequence length `max_seq_length`.\n",
"\n",
"This function returns both the encoder and the classifier."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cH682__U0FBv"
},
"outputs": [],
"source": [
"bert_classifier, bert_encoder = bert.bert_models.classifier_model(\n",
" bert_config, num_labels=2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "XqKp3-5GIZlw"
},
"source": [
"The classifier has three inputs and one output:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "bAQblMIjwkvx"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sFmVG4SKZAw8"
},
"source": [
"Run it on a test batch of data 10 examples from the training set. The output is the logits for the two classes:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "VTjgPbp4ZDKo"
},
"outputs": [],
"source": [
"glue_batch = {key: val[:10] for key, val in glue_train.items()}\n",
"\n",
"bert_classifier(\n",
" glue_batch, training=True\n",
").numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Q0NTdwZsQK8n"
},
"source": [
"The `TransformerEncoder` in the center of the classifier above **is** the `bert_encoder`.\n",
"\n",
"Inspecting the encoder, we see its stack of `Transformer` layers connected to those same three inputs:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "8L__-erBwLIQ"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mKAvkQc3heSy"
},
"source": [
"### Restore the encoder weights\n",
"\n",
"When built the encoder is randomly initialized. Restore the encoder's weights from the checkpoint:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "97Ll2Gichd_Y"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(model=bert_encoder)\n",
"checkpoint.restore(\n",
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2oHOql35k3Dd"
},
"source": [
"Note: The pretrained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). See the [Hub appendix](#hub_bert) for details. "
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "115caFLMk-_l"
},
"source": [
"### Set up the optimizer\n",
"\n",
"BERT adopts the Adam optimizer with weight decay (aka \"[AdamW](https://arxiv.org/abs/1711.05101)\").\n",
"It also employs a learning rate schedule that firstly warms up from 0 and then decays to 0."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "w8qXKRZuCwW4"
},
"outputs": [],
"source": [
"# Set up epochs and steps\n",
"epochs = 3\n",
"batch_size = 32\n",
"eval_batch_size = 32\n",
"\n",
"train_data_size = len(glue_train_labels)\n",
"steps_per_epoch = int(train_data_size / batch_size)\n",
"num_train_steps = steps_per_epoch * epochs\n",
"warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n",
"\n",
"# creates an optimizer with learning rate schedule\n",
"optimizer = nlp.optimization.create_optimizer(\n",
" 2e-5, num_train_steps=num_train_steps, num_warmup_steps=warmup_steps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "pXRGxiRNEHS2"
},
"source": [
"This returns an `AdamWeightDecay` optimizer with the learning rate schedule set:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "eQNA16bhDpky"
},
"outputs": [],
"source": [
"type(optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xqu_K71fJQB8"
},
"source": [
"To see an example of how to customize the optimizer and it's schedule, see the [Optimizer schedule appendix](#optiizer_schedule)."
]
},
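{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lr-schedule-sketch-text"
},
"source": [
"The warmup-then-decay shape of the schedule can be sketched in plain Python. This is a hypothetical re-implementation for illustration only; the real schedule lives inside `nlp.optimization`, and `sketch_lr` is not part of any library:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lr-schedule-sketch-code"
},
"outputs": [],
"source": [
"def sketch_lr(step, peak_lr=2e-5, warmup_steps=100, total_steps=1000):\n",
"  # Linear warmup from 0 to peak_lr over warmup_steps...\n",
"  if step < warmup_steps:\n",
"    return peak_lr * step / warmup_steps\n",
"  # ...then linear (power=1) decay back down to 0 at total_steps.\n",
"  return peak_lr * max((total_steps - step) / (total_steps - warmup_steps), 0.0)\n",
"\n",
"lrs = [sketch_lr(s) for s in range(1001)]\n",
"print(\"peak learning rate:\", max(lrs))"
]
},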
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "78FEUOOEkoP0"
},
"source": [
"### Train the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OTNcA0O0nSq9"
},
"source": [
"The metric is accuracy and we use sparse categorical cross-entropy as loss."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "nzi8hjeTQTRs"
},
"outputs": [],
"source": [
"metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"\n",
"bert_classifier.compile(\n",
" optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)\n",
"\n",
"bert_classifier.fit(\n",
" glue_train, glue_train_labels,\n",
" validation_data=(glue_validation, glue_validation_labels),\n",
" batch_size=32,\n",
" epochs=epochs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IFtKFWbNKb0u"
},
"source": [
"Now run the fine-tuned model on a custom example to see that it works.\n",
"\n",
"Start by encoding some sentence pairs:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9ZoUgDUNJPz3"
},
"outputs": [],
"source": [
"my_examples = bert_encode(\n",
" glue_dict = {\n",
" 'sentence1':[\n",
" 'The rain in Spain falls mainly on the plain.',\n",
" 'Look I fine tuned BERT.'],\n",
" 'sentence2':[\n",
" 'It mostly rains on the flat lands of Spain.',\n",
" 'Is it working? This does not match.']\n",
" },\n",
" tokenizer=tokenizer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7ynJibkBRTJF"
},
"source": [
"The model should report class `1` \"match\" for the first example and class `0` \"no-match\" for the second:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "umo0ttrgRYIM"
},
"outputs": [],
"source": [
"result = bert_classifier(my_examples, training=False)\n",
"\n",
"result = tf.argmax(result).numpy()\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "utGl0M3aZCE4"
},
"outputs": [],
"source": [
"np.array(info.features['label'].names)[result]"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fVo_AnT0l26j"
},
"source": [
"### Save the model\n",
"\n",
"Often the goal of training a model is to _use_ it for something, so export the model and then restore it to be sure that it works."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Nl5x6nElZqkP"
},
"outputs": [],
"source": [
"export_dir='./saved_model'\n",
"tf.saved_model.save(bert_classifier, export_dir=export_dir)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "y_ACvKPsVUXC"
},
"outputs": [],
"source": [
"reloaded = tf.saved_model.load(export_dir)\n",
"reloaded_result = reloaded([my_examples['input_word_ids'],\n",
" my_examples['input_mask'],\n",
" my_examples['input_type_ids']], training=False)\n",
"\n",
"original_result = bert_classifier(my_examples, training=False)\n",
"\n",
"# The results are (nearly) identical:\n",
"print(original_result.numpy())\n",
"print()\n",
"print(reloaded_result.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eQceYqRFT_Eg"
},
"source": [
"## Appendix"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SaC1RlFawUpc"
},
"source": [
"\u003ca id=re_encoding_tools\u003e\u003c/a\u003e\n",
"### Re-encoding a large dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "CwUdjFBkzUgh"
},
"source": [
"This tutorial you re-encoded the dataset in memory, for clarity.\n",
"\n",
"This was only possible because `glue/mrpc` is a very small dataset. To deal with larger datasets `tf_models` library includes some tools for processing and re-encoding a dataset for efficient training."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2UTQrkyOT5wD"
},
"source": [
"The first step is to describe which features of the dataset should be transformed:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XQeDFOzYR9Z9"
},
"outputs": [],
"source": [
"processor = nlp.data.classifier_data_lib.TfdsProcessor(\n",
" tfds_params=\"dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2\",\n",
" process_text_fn=bert.tokenization.convert_to_unicode)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "XrFQbfErUWxa"
},
"source": [
"Then apply the transformation to generate new TFRecord files."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ymw7GOHpSHKU"
},
"outputs": [],
"source": [
"# Set up output of training and evaluation Tensorflow dataset\n",
"train_data_output_path=\"./mrpc_train.tf_record\"\n",
"eval_data_output_path=\"./mrpc_eval.tf_record\"\n",
"\n",
"max_seq_length = 128\n",
"batch_size = 32\n",
"eval_batch_size = 32\n",
"\n",
"# Generate and save training data into a tf record file\n",
"input_meta_data = (\n",
" nlp.data.classifier_data_lib.generate_tf_record_from_data_file(\n",
" processor=processor,\n",
" data_dir=None, # It is `None` because data is from tfds, not local dir.\n",
" tokenizer=tokenizer,\n",
" train_data_output_path=train_data_output_path,\n",
" eval_data_output_path=eval_data_output_path,\n",
" max_seq_length=max_seq_length))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "uX_Sp-wTUoRm"
},
"source": [
"Finally create `tf.data` input pipelines from those TFRecord files:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "rkHxIK57SQ_r"
},
"outputs": [],
"source": [
"training_dataset = bert.run_classifier.get_dataset_fn(\n",
" train_data_output_path,\n",
" max_seq_length,\n",
" batch_size,\n",
" is_training=True)()\n",
"\n",
"evaluation_dataset = bert.run_classifier.get_dataset_fn(\n",
" eval_data_output_path,\n",
" max_seq_length,\n",
" eval_batch_size,\n",
" is_training=False)()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "stbaVouogvzS"
},
"source": [
"The resulting `tf.data.Datasets` return `(features, labels)` pairs, as expected by `keras.Model.fit`:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "gwhrlQl4gxVF"
},
"outputs": [],
"source": [
"training_dataset.element_spec"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dbJ76vSJj77j"
},
"source": [
"#### Create tf.data.Dataset for training and evaluation\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9J95LFRohiYw"
},
"source": [
"If you need to modify the data loading here is some code to get you started:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "gCvaLLAxPuMc"
},
"outputs": [],
"source": [
"def create_classifier_dataset(file_path, seq_length, batch_size, is_training):\n",
" \"\"\"Creates input dataset from (tf)records files for train/eval.\"\"\"\n",
" dataset = tf.data.TFRecordDataset(file_path)\n",
" if is_training:\n",
" dataset = dataset.shuffle(100)\n",
" dataset = dataset.repeat()\n",
"\n",
" def decode_record(record):\n",
" name_to_features = {\n",
" 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
" 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
" 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
" 'label_ids': tf.io.FixedLenFeature([], tf.int64),\n",
" }\n",
" return tf.io.parse_single_example(record, name_to_features)\n",
"\n",
" def _select_data_from_record(record):\n",
" x = {\n",
" 'input_word_ids': record['input_ids'],\n",
" 'input_mask': record['input_mask'],\n",
" 'input_type_ids': record['segment_ids']\n",
" }\n",
" y = record['label_ids']\n",
" return (x, y)\n",
"\n",
" dataset = dataset.map(decode_record,\n",
" num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
" dataset = dataset.map(\n",
" _select_data_from_record,\n",
" num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
" dataset = dataset.batch(batch_size, drop_remainder=is_training)\n",
" dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
" return dataset"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "rutkBadrhzdR"
},
"outputs": [],
"source": [
"# Set up batch sizes\n",
"batch_size = 32\n",
"eval_batch_size = 32\n",
"\n",
"# Return Tensorflow dataset\n",
"training_dataset = create_classifier_dataset(\n",
" train_data_output_path,\n",
" input_meta_data['max_seq_length'],\n",
" batch_size,\n",
" is_training=True)\n",
"\n",
"evaluation_dataset = create_classifier_dataset(\n",
" eval_data_output_path,\n",
" input_meta_data['max_seq_length'],\n",
" eval_batch_size,\n",
" is_training=False)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "59TVgt4Z7fuU"
},
"outputs": [],
"source": [
"training_dataset.element_spec"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QbklKt-w_CiI"
},
"source": [
"\u003ca id=\"hub_bert\"\u003e\u003c/a\u003e\n",
"\n",
"### TFModels BERT on TFHub\n",
"\n",
"You can get [the BERT model](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2) off the shelf from [TFHub](https://tensorflow.org/hub). It would not be hard to add a classification head on top of this `hub.KerasLayer`"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"# Note: 350MB download.\n",
"import tensorflow_hub as hub\n",
"hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n",
"\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iTzF574wivQv"
},
"source": [
"Test run it on a batch of data:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XEcYrCR45Uwo"
},
"outputs": [],
"source": [
"result = hub_encoder(\n",
" inputs=[glue_train['input_word_ids'][:10],\n",
" glue_train['input_mask'][:10],\n",
" glue_train['input_type_ids'][:10],],\n",
" training=False,\n",
")\n",
"\n",
"print(\"Pooled output shape:\", result[0].shape)\n",
"print(\"Sequence output shape:\", result[1].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cjojn8SmLSRI"
},
"source": [
"At this point it would be simple to add a classification head yourself.\n",
"\n",
"The `bert_models.classifier_model` function can also build a classifier onto the encoder from TensorFlow Hub:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9nTDaApyLR70"
},
"outputs": [],
"source": [
"hub_classifier, hub_encoder = bert.bert_models.classifier_model(\n",
" # Caution: Most of `bert_config` is ignored if you pass a hub url.\n",
" bert_config=bert_config, hub_module_url=hub_url_bert, num_labels=2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xMJX3wV0_v7I"
},
"source": [
"The one downside to loading this model from TFHub is that the structure of internal keras layers is not restored. So it's more difficult to inspect or modify the model. The `TransformerEncoder` model is now a single layer:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "pD71dnvhM2QS"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(hub_classifier, show_shapes=True, dpi=64)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "nLZD-isBzNKi"
},
"outputs": [],
"source": [
"try:\n",
" tf.keras.utils.plot_model(hub_encoder, show_shapes=True, dpi=64)\n",
" assert False\n",
"except Exception as e:\n",
" print(f\"{type(e).__name__}: {e}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ZxSqH0dNAgXV"
},
"source": [
"\u003ca id=\"model_builder_functions\"\u003e\u003c/a\u003e\n",
"\n",
"### Low level model building\n",
"\n",
"If you need a more control over the construction of the model it's worth noting that the `classifier_model` function used earlier is really just a thin wrapper over the `nlp.modeling.networks.TransformerEncoder` and `nlp.modeling.models.BertClassifier` classes. Just remember that if you start modifying the architecture it may not be correct or possible to reload the pre-trained checkpoint so you'll need to retrain from scratch."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0cgABEwDj06P"
},
"source": [
"Build the encoder:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5r_yqhBFSVEM"
},
"outputs": [],
"source": [
"transformer_config = config_dict.copy()\n",
"\n",
"# You need to rename a few fields to make this work:\n",
"transformer_config['attention_dropout_rate'] = transformer_config.pop('attention_probs_dropout_prob')\n",
"transformer_config['activation'] = tf_utils.get_activation(transformer_config.pop('hidden_act'))\n",
"transformer_config['dropout_rate'] = transformer_config.pop('hidden_dropout_prob')\n",
"transformer_config['initializer'] = tf.keras.initializers.TruncatedNormal(\n",
" stddev=transformer_config.pop('initializer_range'))\n",
"transformer_config['max_sequence_length'] = transformer_config.pop('max_position_embeddings')\n",
"transformer_config['num_layers'] = transformer_config.pop('num_hidden_layers')\n",
"\n",
"transformer_config"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "rIO8MI7LLijh"
},
"outputs": [],
"source": [
"manual_encoder = nlp.modeling.networks.TransformerEncoder(**transformer_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4a4tFSg9krRi"
},
"source": [
"Restore the weights:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "X6N9NEqfXJCx"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(model=manual_encoder)\n",
"checkpoint.restore(\n",
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1BPiPO4ykuwM"
},
"source": [
"Test run it:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "hlVdgJKmj389"
},
"outputs": [],
"source": [
"result = manual_encoder(my_examples, training=True)\n",
"\n",
"print(\"Sequence output shape:\", result[0].shape)\n",
"print(\"Pooled output shape:\", result[1].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "nJMXvVgJkyBv"
},
"source": [
"Wrap it in a classifier:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "tQX57GJ6wkAb"
},
"outputs": [],
"source": [
"manual_classifier = nlp.modeling.models.BertClassifier(\n",
" bert_encoder,\n",
" num_classes=2,\n",
" dropout_rate=transformer_config['dropout_rate'],\n",
" initializer=tf.keras.initializers.TruncatedNormal(\n",
" stddev=bert_config.initializer_range))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "kB-nBWhQk0dS"
},
"outputs": [],
"source": [
"manual_classifier(my_examples, training=True).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "E6AJlOSyIO1L"
},
"source": [
"\u003ca id=\"optiizer_schedule\"\u003e\u003c/a\u003e\n",
"\n",
"### Optimizers and schedules\n",
"\n",
"The optimizer used to train the model was created using the `nlp.optimization.create_optimizer` function:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "28Dv3BPRlFTD"
},
"outputs": [],
"source": [
"optimizer = nlp.optimization.create_optimizer(\n",
" 2e-5, num_train_steps=num_train_steps, num_warmup_steps=warmup_steps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "LRjcHr0UlT8c"
},
"source": [
"That high level wrapper sets up the learning rate schedules and the optimizer.\n",
"\n",
"The base learning rate schedule used here is a linear decay to zero over the training run:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "MHY8K6kDngQn"
},
"outputs": [],
"source": [
"epochs = 3\n",
"batch_size = 32\n",
"eval_batch_size = 32\n",
"\n",
"train_data_size = len(glue_train_labels)\n",
"steps_per_epoch = int(train_data_size / batch_size)\n",
"num_train_steps = steps_per_epoch * epochs"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "wKIcSprulu3P"
},
"outputs": [],
"source": [
"decay_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
" initial_learning_rate=2e-5,\n",
" decay_steps=num_train_steps,\n",
" end_learning_rate=0)\n",
"\n",
"plt.plot([decay_schedule(n) for n in range(num_train_steps)])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IMTC_gfAl_PZ"
},
"source": [
"This, in turn is wrapped in a `WarmUp` schedule that linearly increases the learning rate to the target value over the first 10% of training:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "YRt3VTmBmCBY"
},
"outputs": [],
"source": [
"warmup_steps = num_train_steps * 0.1\n",
"\n",
"warmup_schedule = nlp.optimization.WarmUp(\n",
" initial_learning_rate=2e-5,\n",
" decay_schedule_fn=decay_schedule,\n",
" warmup_steps=warmup_steps)\n",
"\n",
"# The warmup overshoots, because it warms up to the `initial_learning_rate`\n",
"# following the original implementation. You can set\n",
"# `initial_learning_rate=decay_schedule(warmup_steps)` if you don't like the\n",
"# overshoot.\n",
"plt.plot([warmup_schedule(n) for n in range(num_train_steps)])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "l8D9Lv3Bn740"
},
"source": [
"Then create the `nlp.optimization.AdamWeightDecay` using that schedule, configured for the BERT model:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2Hf2rpRXk89N"
},
"outputs": [],
"source": [
"optimizer = nlp.optimization.AdamWeightDecay(\n",
" learning_rate=warmup_schedule,\n",
" weight_decay_rate=0.01,\n",
" epsilon=1e-6,\n",
" exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])"
] ]
} }
], ],
...@@ -539,8 +1815,10 @@ ...@@ -539,8 +1815,10 @@
"accelerator": "GPU", "accelerator": "GPU",
"colab": { "colab": {
"collapsed_sections": [], "collapsed_sections": [],
"name": "How-to Guide: Using a PIP package for fine-tuning a BERT model", "name": "fine_tuning_bert.ipynb",
"provenance": [] "private_outputs": true,
"provenance": [],
"toc_visible": true
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3",
......
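The warmup-plus-decay schedule assembled in the notebook above can be sketched without TensorFlow. This is a hedged stand-in mirroring the linear (`power=1`) `PolynomialDecay` over the whole training run plus a linear warmup toward the peak rate, including the slight overshoot the notebook comments on; the parameter values echo the notebook's and are not requirements.

```python
# Linear decay from peak_lr to zero over the whole run, mirroring
# tf.keras.optimizers.schedules.PolynomialDecay with power=1.
def decay(step, peak_lr, num_train_steps):
    frac = min(step, num_train_steps) / num_train_steps
    return peak_lr * (1.0 - frac)


# Linear warmup to peak_lr over the first warmup_frac of training, then decay.
# Because warmup targets peak_lr while the decay curve has already dropped,
# the combined schedule briefly overshoots, as noted in the notebook.
def schedule(step, peak_lr=2e-5, num_train_steps=1000, warmup_frac=0.1):
    warmup_steps = num_train_steps * warmup_frac
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return decay(step, peak_lr, num_train_steps)
```

Setting `initial_learning_rate=decay_schedule(warmup_steps)` instead of the peak rate, as the notebook suggests, removes the overshoot.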
@@ -14,15 +14,18 @@
# limitations under the License.
# ==============================================================================
"""Defines the base task abstraction."""
import abc
import functools
from typing import Any, Callable, Optional

import six
import tensorflow as tf

from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry


@six.add_metaclass(abc.ABCMeta)
class Task(tf.Module):
  """A single-replica view of training procedure.
@@ -54,14 +57,13 @@ class Task(tf.Module):
    """
    pass

  @abc.abstractmethod
  def build_model(self) -> tf.keras.Model:
    """Creates the model architecture.

    Returns:
      A model instance.
    """

  def compile_model(self,
                    model: tf.keras.Model,
@@ -98,6 +100,7 @@ class Task(tf.Module):
    model.test_step = functools.partial(validation_step, model=model)
    return model

  @abc.abstractmethod
  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
@@ -112,20 +115,19 @@ class Task(tf.Module):
    Returns:
      A nested structure of per-replica input functions.
    """

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del model_outputs, labels
    if aux_losses is None:
      losses = [tf.constant(0.0, dtype=tf.float32)]
@@ -139,29 +141,29 @@ class Task(tf.Module):
    del training
    return []

  def process_metrics(self, metrics, labels, model_outputs):
    """Processes and updates metrics. Called when using a custom training loop.

    Args:
      metrics: a nested structure of metrics objects; the return of
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Processes and updates compiled_metrics. Called when using compile/fit.

    Args:
      compiled_metrics: the compiled metrics (model.compiled_metrics).
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    compiled_metrics.update_state(labels, model_outputs)

  def train_step(self,
                 inputs,
@@ -187,7 +189,7 @@ class Task(tf.Module):
      outputs = model(features, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
@@ -231,7 +233,7 @@ class Task(tf.Module):
    features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
@@ -245,16 +247,57 @@ class Task(tf.Module):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self, state, step_logs):
    """Optional aggregation over logs returned from a validation step."""
    pass

  def reduce_aggregated_logs(self, aggregated_logs):
    """Optional reduce of aggregated logs over validation steps."""
    return {}


_REGISTERED_TASK_CLS = {}


# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
  """Decorates a factory of Tasks for lookup by a subclass of TaskConfig.

  This decorator supports registration of tasks as follows:

  ```
  @dataclasses.dataclass
  class MyTaskConfig(TaskConfig):
    # Add fields here.
    pass

  @register_task_cls(MyTaskConfig)
  class MyTask(Task):
    # Inherits def __init__(self, task_config).
    pass

  my_task_config = MyTaskConfig()
  my_task = get_task(my_task_config)  # Returns MyTask(my_task_config).
  ```

  Besides a class itself, other callables that create a Task from a TaskConfig
  can be decorated by the result of this function, as long as there is at most
  one registration for each config class.

  Args:
    task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
      Each task_config_cls can only be used for a single registration.

  Returns:
    A callable for use as a class decorator that registers the decorated class
    for creation from an instance of task_config_cls.
  """
  return registry.register(_REGISTERED_TASK_CLS, task_config_cls)


# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
  task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
  return task_cls
@@ -162,19 +162,38 @@ class CallbacksConfig(base_config.Config):
@dataclasses.dataclass
class TrainerConfig(base_config.Config):
  """Configuration for trainer.

  Attributes:
    optimizer_config: optimizer config; it includes optimizer, learning rate,
      and warmup schedule configs.
    train_tf_while_loop: whether or not to use a tf while loop.
    train_tf_function: whether or not to use tf_function for the training loop.
    eval_tf_function: whether or not to use tf_function for eval.
    steps_per_loop: number of steps per loop.
    summary_interval: number of steps between each summary.
    checkpoint_interval: number of steps between checkpoints.
    max_to_keep: max checkpoints to keep.
    continuous_eval_timeout: maximum number of seconds to wait between
      checkpoints; if set to None, continuous eval will wait indefinitely.
  """
  optimizer_config: OptimizationConfig = OptimizationConfig()
  train_steps: int = 0
  validation_steps: Optional[int] = None
  validation_interval: int = 100
  steps_per_loop: int = 1000
  summary_interval: int = 1000
  checkpoint_interval: int = 1000
  max_to_keep: int = 5
  continuous_eval_timeout: Optional[int] = None
  train_tf_while_loop: bool = True
  train_tf_function: bool = True
  eval_tf_function: bool = True


@dataclasses.dataclass
class TaskConfig(base_config.Config):
  model: base_config.Config = None
  train_data: DataConfig = DataConfig()
  validation_data: DataConfig = DataConfig()
@@ -182,13 +201,9 @@ class TaskConfig(base_config.Config):
@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
  """Top-level configuration."""
  task: TaskConfig = TaskConfig()
  trainer: TrainerConfig = TrainerConfig()
  runtime: RuntimeConfig = RuntimeConfig()


_REGISTERED_CONFIGS = {}
...
@@ -39,12 +39,14 @@ class OptimizerConfig(oneof.OneOfConfig):
    adam: adam optimizer config.
    adamw: adam with weight decay.
    lamb: lamb optimizer.
    rmsprop: rmsprop optimizer.
  """
  type: Optional[str] = None
  sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
  adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
  adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
  lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
  rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()


@dataclasses.dataclass
...
@@ -40,6 +40,29 @@ class SGDConfig(base_config.Config):
  momentum: float = 0.0


@dataclasses.dataclass
class RMSPropConfig(base_config.Config):
  """Configuration for RMSprop optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    learning_rate: learning rate for the RMSprop optimizer.
    rho: discounting factor for the RMSprop optimizer.
    momentum: momentum for the RMSprop optimizer.
    epsilon: epsilon value for the RMSprop optimizer; helps with numerical
      stability.
    centered: whether to normalize gradients or not.
  """
  name: str = "RMSprop"
  learning_rate: float = 0.001
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False


@dataclasses.dataclass
class AdamConfig(base_config.Config):
  """Configuration for Adam optimizer.
...
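A hedged sketch of how such a config maps onto the optimizer constructor: a plain `dataclass` stands in for `base_config.Config` (whose real serialization API is not reproduced here), and the resulting dict is what would be splatted into `tf.keras.optimizers.RMSprop(**kwargs)` by the optimizer factory.

```python
from dataclasses import dataclass, asdict


# Plain-dataclass stand-in for RMSPropConfig; illustrative only.
@dataclass
class RMSPropConfig:
    name: str = "RMSprop"
    learning_rate: float = 0.001
    rho: float = 0.9
    momentum: float = 0.0
    epsilon: float = 1e-7
    centered: bool = False


config = RMSPropConfig(learning_rate=0.01)  # override just the learning rate
kwargs = asdict(config)  # would be passed as tf.keras.optimizers.RMSprop(**kwargs)
```

Because every field name matches an `RMSprop` constructor argument, the config can be forwarded wholesale, with unspecified fields falling back to the Keras defaults.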
@@ -14,7 +14,6 @@
# limitations under the License.
# ==============================================================================
"""Optimizer factory class."""
from typing import Union

import tensorflow as tf
@@ -29,7 +28,8 @@ OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'adam': tf.keras.optimizers.Adam,
    'adamw': nlp_optimization.AdamWeightDecay,
    'lamb': tfa_optimizers.LAMB,
    'rmsprop': tf.keras.optimizers.RMSprop
}

LR_CLS = {
...
@@ -15,84 +15,37 @@
# ==============================================================================
"""Tests for optimizer_factory.py."""
from absl.testing import parameterized
import tensorflow as tf

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config


class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      ('sgd'),
      ('rmsprop'),
      ('adam'),
      ('adamw'),
      ('lamb'))
  def test_optimizers(self, optimizer_type):
    params = {
        'optimizer': {
            'type': optimizer_type
        }
    }
    optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
    expected_optimizer_config = optimizer_cls().get_config()

    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()
    optimizer = opt_factory.build_optimizer(lr)

    self.assertIsInstance(optimizer, optimizer_cls)
    self.assertEqual(expected_optimizer_config, optimizer.get_config())

  def test_stepwise_lr_schedule(self):
...
@@ -173,3 +173,18 @@ def assert_rank(tensor, expected_rank, name=None):
        "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
        "equal to the expected tensor rank `%s`" %
        (name, actual_rank, str(tensor.shape), str(expected_rank)))


def safe_mean(losses):
  """Computes a safe mean of the losses.

  Args:
    losses: `Tensor` whose elements contain individual loss measurements.

  Returns:
    A scalar representing the mean of `losses`. If `losses` has zero elements,
    then zero is returned.
  """
  total = tf.reduce_sum(losses)
  num_elements = tf.cast(tf.size(losses), dtype=losses.dtype)
  return tf.math.divide_no_nan(total, num_elements)
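The `divide_no_nan` behaviour is the whole point of `safe_mean`: an empty loss tensor yields 0 rather than the NaN of 0/0. A TensorFlow-free sketch of the same contract:

```python
# Pure-Python sketch of safe_mean: guard the 0/0 case of an empty loss list,
# returning 0.0 instead of raising or producing NaN.
def safe_mean(losses):
    total = sum(losses)
    num_elements = len(losses)
    return total / num_elements if num_elements else 0.0
```

This matters in distributed loops where some replicas can legitimately receive zero loss elements in a step.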
@@ -25,7 +25,6 @@
from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import models
from official.nlp.modeling import networks
@@ -67,22 +66,27 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output_logits,
           sentence_output_logits,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output_logits = tf.cast(lm_output_logits, tf.float32)

    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        lm_label_ids, lm_output_logits, from_logits=True)
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
                                            lm_denominator_loss)

    if sentence_labels is not None:
      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_output_logits, from_logits=True)
      sentence_loss = tf.reduce_mean(sentence_loss)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
@@ -92,8 +96,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
    # TODO(hongkuny): Avoid the hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output_logits, sentence_labels,
                      sentence_loss)
    return final_loss
@@ -228,11 +232,12 @@ def pretrain_model(bert_config,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='logits')
  outputs = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
  lm_output = outputs['masked_lm']
  sentence_output = outputs['classification']
  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
...
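The weighted masked-LM loss above can be illustrated without TensorFlow: per-position sparse categorical cross-entropy computed from logits, then averaged over the label weights with the same divide-no-nan guard. The toy logits and labels below are purely illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a single position's logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def masked_lm_loss(logits_per_position, label_ids, label_weights):
    """Weight-averaged sparse categorical cross-entropy from logits."""
    per_position = [-math.log(softmax(logits)[label])
                    for logits, label in zip(logits_per_position, label_ids)]
    numerator = sum(l * w for l, w in zip(per_position, label_weights))
    denominator = sum(label_weights)
    # divide_no_nan-style guard: all-padding batches produce zero loss.
    return numerator / denominator if denominator else 0.0


# Two prediction positions; the second has weight 0 (padding) and is ignored.
loss = masked_lm_loss(
    logits_per_position=[[2.0, 0.0, 0.0], [0.0, 5.0, 0.0]],
    label_ids=[0, 2],
    label_weights=[1.0, 0.0])
```

Weighting by `lm_label_weights` is what keeps padded prediction slots from diluting the loss, which is why the denominator is the weight sum rather than the position count.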
@@ -247,3 +247,39 @@ def create_squad_dataset(file_path,
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset


def create_retrieval_dataset(file_path,
                             seq_length,
                             batch_size,
                             input_pipeline_context=None):
  """Creates input dataset from (tf)records files for scoring."""
  name_to_features = {
      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'int_iden': tf.io.FixedLenFeature([1], tf.int64),
  }
  dataset = single_file_dataset(file_path, name_to_features)

  # The dataset is always sharded by number of hosts.
  # num_input_pipelines is the number of hosts rather than number of cores.
  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
                            input_pipeline_context.input_pipeline_id)

  def _select_data_from_record(record):
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    y = record['int_iden']
    return (x, y)

  dataset = dataset.map(
      _select_data_from_record,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=False)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset
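The host-level sharding above can be sketched on a plain list: `Dataset.shard(num_shards, index)` keeps every `num_shards`-th record starting at offset `index`, so the shards partition the data across input pipelines. This stand-in mimics that contract without tf.data.

```python
# List-based sketch of Dataset.shard(num_shards, index): keep records whose
# position modulo num_shards equals this pipeline's index.
def shard(records, num_shards, index):
    return [r for i, r in enumerate(records) if i % num_shards == index]


records = list(range(10))
host0 = shard(records, 2, 0)  # even-indexed records
host1 = shard(records, 2, 1)  # odd-indexed records
```

Sharding by host (not by core) keeps each input pipeline reading a disjoint slice of the file, so no record is fed to two hosts.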
@@ -111,6 +111,7 @@ def run_customized_training_loop(
    model_dir=None,
    train_input_fn=None,
    steps_per_epoch=None,
    num_eval_per_epoch=1,
    steps_per_loop=None,
    epochs=1,
    eval_input_fn=None,
@@ -144,6 +145,7 @@ def run_customized_training_loop(
    steps_per_epoch: Number of steps to run per epoch. At the end of each
      epoch, model checkpoint will be saved and evaluation will be conducted
      if evaluation dataset is provided.
    num_eval_per_epoch: Number of evaluations per epoch.
    steps_per_loop: Number of steps per graph-mode loop. In order to reduce
      communication in eager context, training logs are printed every
      steps_per_loop.
@@ -158,16 +160,17 @@ def run_customized_training_loop(
    init_checkpoint: Optional checkpoint to load to `sub_model` returned by
      `model_fn`.
    custom_callbacks: A list of Keras Callbacks objects to run during
      training. More specifically, `on_train_begin()`, `on_train_end()`,
      `on_batch_begin()`, `on_batch_end()`, `on_epoch_begin()`, and
      `on_epoch_end()` methods are invoked during training. Note that some
      metrics may be missing from `logs`.
    run_eagerly: Whether to run model training in pure eager execution. This
      should be disabled for TPUStrategy.
    sub_model_export_name: If not None, will export `sub_model` returned by
      `model_fn` into checkpoint files. The name of an intermediate checkpoint
      file is {sub_model_export_name}_step_{step}.ckpt, and the last
      checkpoint's name is {sub_model_export_name}.ckpt; if None, `sub_model`
      will not be exported as a checkpoint.
    explicit_allreduce: Whether to explicitly perform gradient allreduce,
      instead of relying on implicit allreduce in optimizer.apply_gradients().
      Default is False. For now, if training using FP16 mixed precision,
@@ -177,10 +180,10 @@ def run_customized_training_loop(
    pre_allreduce_callbacks: A list of callback functions that take gradient
      and model-variable pairs as input, manipulate them, and return new
      gradient and model-variable pairs. The callback functions will be
      invoked in list order, before gradients are allreduced. With mixed
      precision training, the pre_allreduce_callbacks will be applied on
      scaled_gradients. Default is no callbacks. Only used when
      explicit_allreduce=True.
    post_allreduce_callbacks: A list of callback functions that take gradient
      and model-variable pairs as input, manipulate them, and return new
      gradient and model-variable pairs. The callback
@@ -208,6 +211,8 @@ def run_customized_training_loop(
  required_arguments = [
      strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
  ]
  steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)
  if [arg for arg in required_arguments if arg is None]:
    raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
                     '`steps_per_epoch` and `train_input_fn` are required '
@@ -216,17 +221,17 @@ def run_customized_training_loop(
    if tf.config.list_logical_devices('TPU'):
      # One can't fully utilize a TPU with steps_per_loop=1, so in this case
      # default users to a more useful value.
      steps_per_loop = min(1000, steps_between_evals)
    else:
      steps_per_loop = 1
    logging.info('steps_per_loop not specified. Using steps_per_loop=%d',
                 steps_per_loop)
  if steps_per_loop > steps_between_evals:
    logging.warning(
        'steps_per_loop: %d is specified to be greater than '
        'steps_between_evals: %d, we will use steps_between_evals as '
        'steps_per_loop.', steps_per_loop, steps_between_evals)
    steps_per_loop = steps_between_evals

  assert tf.executing_eagerly()

  if run_eagerly:
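The new clamping logic above ties the graph-mode micro-loop to the evaluation interval rather than to the epoch. A minimal pure-Python sketch (the helper name `resolve_steps_per_loop` is hypothetical; it only mirrors the arithmetic shown in the diff):

```python
def resolve_steps_per_loop(steps_per_epoch, num_eval_per_epoch,
                           steps_per_loop=None, on_tpu=False):
    # Evaluation now runs num_eval_per_epoch times per epoch, so the
    # graph-mode micro-loop must never run past the next evaluation point.
    steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)
    if steps_per_loop is None:
        # Unspecified: TPUs default to a large loop, others to 1.
        steps_per_loop = min(1000, steps_between_evals) if on_tpu else 1
    if steps_per_loop > steps_between_evals:
        steps_per_loop = steps_between_evals  # clamp to the eval interval
    return steps_between_evals, steps_per_loop

# With 20 steps per epoch and 4 evals per epoch, evaluation runs every
# 5 steps, and a requested steps_per_loop of 10 is clamped down to 5.
print(resolve_steps_per_loop(20, 4, steps_per_loop=10))  # -> (5, 5)
```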
@@ -242,12 +247,9 @@ def run_customized_training_loop(
    raise ValueError(
        'if `metric_fn` is specified, metric_fn must be a callable.')

  total_training_steps = steps_per_epoch * epochs
  train_iterator = _get_input_iterator(train_input_fn, strategy)
  eval_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)

  with distribution_utils.get_strategy_scope(strategy):
    # To correctly place the model weights on accelerators,
@@ -260,6 +262,9 @@ def run_customized_training_loop(
      raise ValueError('sub_model_export_name is specified as %s, but '
                       'sub_model is None.' % sub_model_export_name)

    callback_list = tf.keras.callbacks.CallbackList(
        callbacks=custom_callbacks, model=model)

    optimizer = model.optimizer

    if init_checkpoint:
@@ -270,8 +275,7 @@ def run_customized_training_loop(
      checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
      logging.info('Loading from checkpoint file completed')

    train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    eval_metrics = [metric_fn()] if metric_fn else []
    # If evaluation is required, make a copy of metric as it will be used by
    # both train and evaluation.
@@ -440,18 +444,20 @@ def run_customized_training_loop(
    latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
    if latest_checkpoint_file:
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(latest_checkpoint_file)
      logging.info('Loading from checkpoint file completed')
    current_step = optimizer.iterations.numpy()
    checkpoint_name = 'ctl_step_{step}.ckpt'

    logs = {}
    callback_list.on_train_begin()
    while current_step < total_training_steps and not model.stop_training:
      if current_step % steps_per_epoch == 0:
        callback_list.on_epoch_begin(int(current_step / steps_per_epoch) + 1)

      # Training loss/metric are averaged over the steps inside the micro
      # training loop. We reset their values before each round.
@@ -461,7 +467,7 @@ def run_customized_training_loop(
      callback_list.on_batch_begin(current_step)
      # Runs several steps in the host while loop.
      steps = steps_to_run(current_step, steps_between_evals, steps_per_loop)

      if tf.config.list_physical_devices('GPU'):
        # TODO(zongweiz): merge with train_steps once tf.while_loop
@@ -470,11 +476,9 @@ def run_customized_training_loop(
        train_single_step(train_iterator)
      else:
        # Converts steps to a Tensor to avoid tf.function retracing.
        train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
      train_loss = _float_metric_value(train_loss_metric)
      current_step += steps

      # Updates training logging.
      training_status = 'Train Step: %d/%d / loss = %s' % (
@@ -492,8 +496,7 @@ def run_customized_training_loop(
              'learning_rate',
              optimizer.learning_rate(current_step),
              step=current_step)
        tf.summary.scalar(train_loss_metric.name, train_loss, step=current_step)
        for metric in train_metrics + model.metrics:
          metric_value = _float_metric_value(metric)
          training_status += ' %s = %f' % (metric.name, metric_value)
@@ -501,7 +504,11 @@ def run_customized_training_loop(
        summary_writer.flush()
      logging.info(training_status)

      # If no evaluation is needed, only call on_batch_end with train_loss;
      # this ensures granular global_step/sec reporting on TensorBoard.
      if current_step % steps_between_evals:
        callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
      else:
        # Save a submodel with the step in the file name after each epoch.
        if sub_model_export_name:
          _save_checkpoint(
@@ -514,7 +521,6 @@ def run_customized_training_loop(
        if current_step < total_training_steps:
          _save_checkpoint(strategy, checkpoint, model_dir,
                           checkpoint_name.format(step=current_step))
          if eval_input_fn:
            logging.info('Running evaluation after step: %s.', current_step)
            logs = _run_evaluation(current_step,
@@ -523,8 +529,15 @@ def run_customized_training_loop(
            eval_loss_metric.reset_states()
            for metric in eval_metrics + model.metrics:
              metric.reset_states()
          # We add train_loss here rather than calling on_batch_end twice, so
          # that no duplicated values are generated.
          logs['loss'] = train_loss
          callback_list.on_batch_end(current_step - 1, logs)

      # Call on_epoch_end only after a real epoch ends, to prevent
      # miscounting of training steps.
      if current_step % steps_per_epoch == 0:
        callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
    if sub_model_export_name:
      _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
@@ -532,14 +545,11 @@ def run_customized_training_loop(
    _save_checkpoint(strategy, checkpoint, model_dir,
                     checkpoint_name.format(step=current_step))
    if eval_input_fn:
      logging.info('Running final evaluation after training is complete.')
      logs = _run_evaluation(current_step,
                             _get_input_iterator(eval_input_fn, strategy))
    callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
    training_summary = {
        'total_training_steps': total_training_steps,
        'train_loss': _float_metric_value(train_loss_metric),
@@ -557,4 +567,6 @@ def run_customized_training_loop(
    if not _should_export_summary(strategy):
      tf.io.gfile.rmtree(summary_dir)

    callback_list.on_train_end()

    return model
@@ -258,6 +258,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
        loss_fn=tf.keras.losses.categorical_crossentropy,
        model_dir=model_dir,
        steps_per_epoch=20,
        num_eval_per_epoch=4,
        steps_per_loop=10,
        epochs=2,
        train_input_fn=input_fn,
@@ -269,14 +270,15 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
        run_eagerly=False)
    self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
    epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
    self.assertEqual(list(epoch_ends), [1, 2, 2])
    for info in epoch_end_infos:
      self.assertIn('accuracy', info)
    self.assertEqual(callback.batch_begin, [(0, {}), (5, {}), (10, {}),
                                            (15, {}), (20, {}), (25, {}),
                                            (30, {}), (35, {})])
    batch_ends, batch_end_infos = zip(*callback.batch_end)
    self.assertEqual(list(batch_ends), [4, 9, 14, 19, 24, 29, 34, 39])
    for info in batch_end_infos:
      self.assertIn('loss', info)
...
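The updated test expectations follow directly from the loop structure: with `steps_per_epoch=20`, `num_eval_per_epoch=4`, and `epochs=2`, the loop advances in chunks of `steps_between_evals=5` (the requested `steps_per_loop=10` is clamped to 5), and the final evaluation after training fires `on_epoch_end` for epoch 2 a second time, hence `[1, 2, 2]`. A sketch of how the batch boundaries are derived (the helper name is hypothetical):

```python
def expected_callback_steps(steps_per_epoch, num_eval_per_epoch, epochs):
    steps_between_evals = steps_per_epoch // num_eval_per_epoch
    total_steps = steps_per_epoch * epochs
    # on_batch_begin fires at the start of every steps_between_evals chunk;
    # on_batch_end reports current_step - 1 once the chunk completes.
    batch_begin = list(range(0, total_steps, steps_between_evals))
    batch_end = [step + steps_between_evals - 1 for step in batch_begin]
    return batch_begin, batch_end

begin, end = expected_callback_steps(20, 4, 2)
# begin: [0, 5, 10, 15, 20, 25, 30, 35]; end: [4, 9, 14, 19, 24, 29, 34, 39]
```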
@@ -61,7 +61,11 @@ def define_common_squad_flags():
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')

# Predict processing related.
flags.DEFINE_string('predict_file', None,
                    'SQuAD prediction json file path. '
                    '`predict` mode supports multiple files: a wildcard can '
                    'be used to specify multiple files, as can multiple file '
                    'patterns separated by commas. Note that `eval` mode '
                    'only supports a single predict file.')
flags.DEFINE_bool(
    'do_lower_case', True,
    'Whether to lower case the input text. Should be True for uncased '
@@ -159,22 +163,9 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
  return _dataset_fn
def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
                               input_meta_data):
  """Gets a squad model to make predictions."""
  with strategy.scope():
    # Prediction always uses float32, even if training uses mixed precision.
    tf.keras.mixed_precision.experimental.set_policy('float32')
@@ -188,6 +179,23 @@ def predict_squad_customized(strategy,
    logging.info('Restoring checkpoints from %s', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=squad_model)
    checkpoint.restore(checkpoint_path).expect_partial()
  return squad_model
def predict_squad_customized(strategy,
                             input_meta_data,
                             predict_tfrecord_path,
                             num_steps,
                             squad_model):
  """Make predictions using a Bert-based squad model."""
  predict_dataset_fn = get_dataset_fn(
      predict_tfrecord_path,
      input_meta_data['max_seq_length'],
      FLAGS.predict_batch_size,
      is_training=False)
  predict_iterator = iter(
      strategy.experimental_distribute_datasets_from_function(
          predict_dataset_fn))

  @tf.function
  def predict_step(iterator):
@@ -287,8 +295,8 @@ def train_squad(strategy,
      post_allreduce_callbacks=[clip_by_global_norm_callback])


def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
                            predict_file, squad_model):
  """Makes predictions for a squad dataset."""
  doc_stride = input_meta_data['doc_stride']
  max_query_length = input_meta_data['max_query_length']
@@ -296,7 +304,7 @@ def prediction_output_squad(
  version_2_with_negative = input_meta_data.get('version_2_with_negative',
                                                False)
  eval_examples = squad_lib.read_squad_examples(
      input_file=predict_file,
      is_training=False,
      version_2_with_negative=version_2_with_negative)
@@ -337,8 +345,7 @@ def prediction_output_squad(
  num_steps = int(dataset_size / FLAGS.predict_batch_size)
  all_results = predict_squad_customized(
      strategy, input_meta_data, eval_writer.filename, num_steps, squad_model)

  all_predictions, all_nbest_json, scores_diff_json = (
      squad_lib.postprocess_output(
@@ -356,11 +363,14 @@ def prediction_output_squad(


def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
                  squad_lib, version_2_with_negative, file_prefix=''):
  """Save output to json files."""
  output_prediction_file = os.path.join(FLAGS.model_dir,
                                        '%spredictions.json' % file_prefix)
  output_nbest_file = os.path.join(FLAGS.model_dir,
                                   '%snbest_predictions.json' % file_prefix)
  # file_prefix is a file-name prefix, not a directory component.
  output_null_log_odds_file = os.path.join(FLAGS.model_dir,
                                           '%snull_odds.json' % file_prefix)
  logging.info('Writing predictions to: %s', output_prediction_file)
  logging.info('Writing nbest to: %s', output_nbest_file)
@@ -370,6 +380,22 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
  squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def _get_matched_files(input_path):
  """Returns all files that match the input_path."""
  input_patterns = input_path.strip().split(',')
  all_matched_files = []
  for input_pattern in input_patterns:
    input_pattern = input_pattern.strip()
    if not input_pattern:
      continue
    matched_files = tf.io.gfile.glob(input_pattern)
    if not matched_files:
      raise ValueError('%s does not match any files.' % input_pattern)
    else:
      all_matched_files.extend(matched_files)
  return sorted(all_matched_files)
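The same pattern-splitting logic can be sketched without TensorFlow by substituting the standard library's `glob.glob` for `tf.io.gfile.glob` (local filesystem only; `tf.io.gfile` additionally handles remote filesystems). The injectable `glob_fn` parameter is an assumption added here so the logic can be exercised without touching disk:

```python
import glob


def get_matched_files(input_path, glob_fn=glob.glob):
    """Splits a comma-separated pattern list and expands each pattern."""
    all_matched_files = []
    for input_pattern in input_path.strip().split(','):
        input_pattern = input_pattern.strip()
        if not input_pattern:
            continue  # tolerate trailing/double commas
        matched_files = glob_fn(input_pattern)
        if not matched_files:
            raise ValueError('%s does not match any files.' % input_pattern)
        all_matched_files.extend(matched_files)
    return sorted(all_matched_files)


# A fake "filesystem" stands in for glob so no real files are needed:
fake_fs = {'xquad.*.json': ['xquad.ar.json', 'xquad.de.json'],
           'squad.json': ['squad.json']}
files = get_matched_files('xquad.*.json, squad.json', fake_fs.get)
# files == ['squad.json', 'xquad.ar.json', 'xquad.de.json']
```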
def predict_squad(strategy,
                  input_meta_data,
                  tokenizer,
@@ -379,11 +405,24 @@ def predict_squad(strategy,
  """Gets prediction results and writes them to disk."""
  if init_checkpoint is None:
    init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)

  all_predict_files = _get_matched_files(FLAGS.predict_file)
  squad_model = get_squad_model_to_predict(strategy, bert_config,
                                           init_checkpoint, input_meta_data)
  for idx, predict_file in enumerate(all_predict_files):
    all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
        strategy, input_meta_data, tokenizer, squad_lib, predict_file,
        squad_model)
    if len(all_predict_files) == 1:
      file_prefix = ''
    else:
      # If predict_file is /path/xquad.ar.json, `file_prefix` will be
      # "xquad.ar-".
      file_prefix = '%s-' % os.path.splitext(
          os.path.basename(all_predict_files[idx]))[0]
    dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
                  input_meta_data.get('version_2_with_negative', False),
                  file_prefix)
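How the per-file output prefix is derived: a single input keeps the legacy file names, while multiple inputs get a prefix built from `os.path.basename` and `os.path.splitext`, so the output for `/path/xquad.ar.json` becomes `xquad.ar-predictions.json`. A sketch (the helper name `output_file_prefix` is hypothetical):

```python
import os


def output_file_prefix(predict_files, idx):
    # A single input keeps the legacy names (predictions.json, ...).
    if len(predict_files) == 1:
        return ''
    # /path/xquad.ar.json -> basename xquad.ar.json -> splitext -> xquad.ar
    return '%s-' % os.path.splitext(os.path.basename(predict_files[idx]))[0]


print(output_file_prefix(['/path/xquad.ar.json', '/path/xquad.de.json'], 0))
# -> 'xquad.ar-'
```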
def eval_squad(strategy,
@@ -395,9 +434,17 @@ def eval_squad(strategy,
  """Get prediction results and evaluate them against ground truth."""
  if init_checkpoint is None:
    init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)

  all_predict_files = _get_matched_files(FLAGS.predict_file)
  if len(all_predict_files) != 1:
    raise ValueError('`eval_squad` only supports one predict file, '
                     'but got %s' % all_predict_files)

  squad_model = get_squad_model_to_predict(strategy, bert_config,
                                           init_checkpoint, input_meta_data)
  all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
      strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
      squad_model)
  dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
                input_meta_data.get('version_2_with_negative', False))
...
@@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-head BERT encoder network with classification heads.

Includes configurations and instantiation methods.
"""
from typing import List, Optional, Text

import dataclasses
@@ -24,7 +27,6 @@ from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import bert_pretrainer
@@ -47,43 +49,34 @@ class BertPretrainerConfig(base_config.Config):
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[ClsHeadConfig]
) -> List[layers.ClassificationHead]:
  return [
      layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
  ] if cls_head_configs else []


def instantiate_bertpretrainer_from_cfg(
    config: BertPretrainerConfig,
    encoder_network: Optional[tf.keras.Model] = None
) -> bert_pretrainer.BertPretrainerV2:
  """Instantiates a BertPretrainer from the config."""
  encoder_cfg = config.encoder
  if encoder_network is None:
    encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
  return bert_pretrainer.BertPretrainerV2(
      config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      encoder_network=encoder_network,
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads))
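The heads factory extracted above is, at its core, a guarded list comprehension: an empty or missing config list yields no heads instead of failing. A TensorFlow-free sketch with stand-in config and head classes (both hypothetical; the real `ClsHeadConfig` and `layers.ClassificationHead` have different fields):

```python
import dataclasses


@dataclasses.dataclass
class ClsHeadConfig:
    # Stand-in for the real ClsHeadConfig; fields are illustrative only.
    inner_dim: int = 768
    num_classes: int = 2
    name: str = 'sentence_prediction'

    def as_dict(self):
        return dataclasses.asdict(self)


class ClassificationHead:
    # Stand-in for layers.ClassificationHead.
    def __init__(self, inner_dim, num_classes, name):
        self.inner_dim = inner_dim
        self.num_classes = num_classes
        self.name = name


def instantiate_classification_heads_from_cfgs(cls_head_configs):
    # An empty (or None) config list yields no heads instead of failing.
    return [
        ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
    ] if cls_head_configs else []


heads = instantiate_classification_heads_from_cfgs([ClsHeadConfig()])
# one head, carrying the config's fields; [] and None both yield []
```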
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  input_path: str = ""
  global_batch_size: int = 512
  is_training: bool = True
@@ -95,15 +88,15 @@ class BertPretrainDataConfig(cfg.DataConfig):
 @dataclasses.dataclass
 class BertPretrainEvalDataConfig(BertPretrainDataConfig):
-  """Data config for the eval set in BERT pretraining task."""
+  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
   input_path: str = ""
   global_batch_size: int = 512
   is_training: bool = False
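The eval config above subclasses the training config only to flip defaults. A hedged, self-contained sketch of that pattern; `DataConfig` here is a plain stand-in, not the actual `cfg.DataConfig` base class:

```python
# "Subclass only to change defaults" pattern, as used by
# BertPretrainEvalDataConfig. DataConfig is an illustrative stand-in.
import dataclasses


@dataclasses.dataclass
class DataConfig:
    input_path: str = ""
    global_batch_size: int = 512
    is_training: bool = True


@dataclasses.dataclass
class EvalDataConfig(DataConfig):
    # Same fields as the parent; only the default flips, so any code that
    # consumes a DataConfig works unchanged for both train and eval.
    is_training: bool = False


cfg = EvalDataConfig(input_path="/tmp/eval.tfrecord")
```

Because the subclass adds no new fields, every input pipeline written against the parent type accepts the eval variant as-is.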
 @dataclasses.dataclass
-class BertSentencePredictionDataConfig(cfg.DataConfig):
-  """Data of sentence prediction dataset."""
+class SentencePredictionDataConfig(cfg.DataConfig):
+  """Data config for sentence prediction task (tasks/sentence_prediction)."""
   input_path: str = ""
   global_batch_size: int = 32
   is_training: bool = True
@@ -111,10 +104,55 @@ class BertSentencePredictionDataConfig(cfg.DataConfig):
 @dataclasses.dataclass
-class BertSentencePredictionDevDataConfig(cfg.DataConfig):
-  """Dev data of MNLI sentence prediction dataset."""
+class SentencePredictionDevDataConfig(cfg.DataConfig):
+  """Dev data config for sentence prediction (tasks/sentence_prediction)."""
   input_path: str = ""
   global_batch_size: int = 32
   is_training: bool = False
   seq_length: int = 128
   drop_remainder: bool = False
+@dataclasses.dataclass
+class QADataConfig(cfg.DataConfig):
+  """Data config for question answering task (tasks/question_answering)."""
+  input_path: str = ""
+  global_batch_size: int = 48
+  is_training: bool = True
+  seq_length: int = 384
+
+
+@dataclasses.dataclass
+class QADevDataConfig(cfg.DataConfig):
+  """Dev data config for question answering (tasks/question_answering)."""
+  input_path: str = ""
+  input_preprocessed_data_path: str = ""
+  version_2_with_negative: bool = False
+  doc_stride: int = 128
+  global_batch_size: int = 48
+  is_training: bool = False
+  seq_length: int = 384
+  query_length: int = 64
+  drop_remainder: bool = False
+  vocab_file: str = ""
+  tokenization: str = "WordPiece"  # WordPiece or SentencePiece
+  do_lower_case: bool = True
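`doc_stride`, `seq_length`, and `query_length` in `QADevDataConfig` are the usual SQuAD-style sliding-window parameters. A hedged sketch of how a stride turns a long passage into overlapping windows; the helper name and numbers are illustrative, not the actual input pipeline (which would also reserve room for the query and special tokens inside `seq_length`):

```python
def sliding_windows(doc_len: int, max_len: int, doc_stride: int):
    """Return (start, end) token offsets of overlapping windows over a doc.

    Illustrative only: a real QA pipeline would use
    max_len = seq_length - query_length - number_of_special_tokens.
    """
    spans = []
    start = 0
    while True:
        end = min(start + max_len, doc_len)
        spans.append((start, end))
        if end == doc_len:
            break
        start += doc_stride
    return spans


# A 600-token passage with 384-token windows and doc_stride=128 yields
# three overlapping windows: (0, 384), (128, 512), (256, 600).
windows = sliding_windows(doc_len=600, max_len=384, doc_stride=128)
```

Each window is scored independently at eval time, and overlapping predictions are reconciled afterward, which is why the dev config carries these fields while the training config does not.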
+@dataclasses.dataclass
+class TaggingDataConfig(cfg.DataConfig):
+  """Data config for tagging (tasks/tagging)."""
+  input_path: str = ""
+  global_batch_size: int = 48
+  is_training: bool = True
+  seq_length: int = 384
+
+
+@dataclasses.dataclass
+class TaggingDevDataConfig(cfg.DataConfig):
+  """Dev data config for tagging (tasks/tagging)."""
+  input_path: str = ""
+  global_batch_size: int = 48
+  is_training: bool = False
+  seq_length: int = 384
+  drop_remainder: bool = False