Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
> :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
>
> * Template version: 1.0.2020.125
> * Template version: 1.0.2020.170
> * Please modify sections depending on needs.
# Model name, Paper title, or Project Name
> :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)
[![Paper](http://img.shields.io/badge/paper-arXiv.YYMM.NNNNN-B3181B.svg)](https://arxiv.org/abs/...)
[![Paper](http://img.shields.io/badge/Paper-arXiv.YYMM.NNNNN-B3181B?logo=arXiv)](https://arxiv.org/abs/...)
This repository is the official or unofficial implementation of the following paper.
......@@ -28,8 +28,8 @@ This repository is the official or unofficial implementation of the following pa
> :memo: Provide maintainer information.
* Last name, First name ([@GitHub username](https://github.com/username))
* Last name, First name ([@GitHub username](https://github.com/username))
* Full name ([@GitHub username](https://github.com/username))
* Full name ([@GitHub username](https://github.com/username))
## Table of Contents
......@@ -37,8 +37,8 @@ This repository is the official or unofficial implementation of the following pa
## Requirements
[![TensorFlow 2.1](https://img.shields.io/badge/tensorflow-2.1-brightgreen)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
[![Python 3.6](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/downloads/release/python-360/)
[![TensorFlow 2.1](https://img.shields.io/badge/TensorFlow-2.1-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
[![Python 3.6](https://img.shields.io/badge/Python-3.6-3776AB)](https://www.python.org/downloads/release/python-360/)
> :memo: Provide details of the software required.
>
......@@ -54,6 +54,8 @@ pip install -r requirements.txt
## Results
[![TensorFlow Hub](https://img.shields.io/badge/TF%20Hub-Models-FF6F00?logo=tensorflow)](https://tfhub.dev/...)
> :memo: Provide a table with results. (e.g., accuracy, latency)
>
> * Provide links to the pre-trained models (checkpoint, SavedModel files).
......@@ -104,6 +106,8 @@ python3 ...
## License
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
> :memo: Place your license text in a file named LICENSE in the root of the repository.
>
> * Include information about your license.
......
......@@ -2,28 +2,34 @@
# Welcome to the Model Garden for TensorFlow
The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users can take full advantage of TensorFlow for their research and product development.
The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users
can take full advantage of TensorFlow for their research and product development.
| Directory | Description |
|-----------|-------------|
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |
## [Announcements](../../wiki/Announcements)
## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
| Date | News |
|------|------|
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1
| July 10, 2020 | TensorFlow 2 meets the [Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) ([Blog](https://blog.tensorflow.org/2020/07/tensorflow-2-meets-object-detection-api.html)) |
| June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) |
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released |
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection |
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
## Contributions
[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
If you want to contribute, please review the [contribution guidelines](../../wiki/How-to-contribute).
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
## License
......
......@@ -6,13 +6,12 @@ This repository provides a curated list of the GitHub repositories with machine
**Note**: Contributing companies or individuals are responsible for maintaining their repositories.
## Models / Implementations
## Computer Vision
### Computer Vision
### Image Recognition
#### Image Recognition
| Model | Reference (Paper) | Features | Maintainer |
|-------|-------------------|----------|------------|
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [DenseNet 169](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/densenet169) | [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V3](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv3) | [Rethinking the Inception Architecture<br/>for Computer Vision](https://arxiv.org/pdf/1512.00567.pdf) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V4](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv4) | [Inception-v4, Inception-ResNet and the Impact<br/>of Residual Connections on Learning](https://arxiv.org/pdf/1602.07261) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
......@@ -21,12 +20,21 @@ This repository provides a curated list of the GitHub repositories with machine
| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
#### Segmentation
| Model | Reference (Paper) | &nbsp; &nbsp; &nbsp; Features &nbsp; &nbsp; &nbsp; | Maintainer |
|-------|-------------------|----------|------------|
### Object Detection
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Segmentation
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/MaskRCNN) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
| [U-Net Medical Image Segmentation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/UNet_Medical) | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
## Contributions
If you want to contribute, please review the [contribution guidelines](../../../wiki/How-to-contribute).
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
......@@ -17,11 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
The team is actively developing new models.
In the near future, we will add:
* State-of-the-art language understanding models:
More members in Transformer family
* Start-of-the-art image classification models:
EfficientNet, MnasNet, and variants
* A set of excellent objection detection models.
* State-of-the-art language understanding models.
* State-of-the-art image classification models.
* State-of-the-art objection detection and instance segmentation models.
## Table of Contents
......@@ -43,6 +41,7 @@ In the near future, we will add:
|-------|-------------------|
| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
| [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
#### Object Detection and Segmentation
......@@ -50,6 +49,8 @@ In the near future, we will add:
|-------|-------------------|
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing
......
......@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark(summary_path=summary_path,
report_accuracy=True)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 2x2 TPU for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 2x2 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 4x4 TPU for 10000 steps."""
......@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 4x4 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 8x8 TPU for 10000 steps."""
......
......@@ -299,20 +299,21 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
return os.path.join(self.output_dir, folder_name)
class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Resnet50 (classifier_trainer) benchmarks."""
class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Classifier Trainer benchmarks."""
def __init__(self, output_dir=None, default_flags=None,
def __init__(self, model, output_dir=None, default_flags=None,
tpu=None, dataset_builder='records', train_epochs=1,
train_steps=110, data_dir=None):
flag_methods = [classifier_trainer.define_classifier_flags]
self.model = model
self.dataset_builder = dataset_builder
self.train_epochs = train_epochs
self.train_steps = train_steps
self.data_dir = data_dir
super(Resnet50KerasClassifierBenchmarkBase, self).__init__(
super(KerasClassifierBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags,
......@@ -337,7 +338,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
dataset_num_private_threads: Optional[int] = None,
loss_scale: Optional[str] = None):
"""Runs and reports the benchmark given the provided configuration."""
FLAGS.model_type = 'resnet'
FLAGS.model_type = self.model
FLAGS.dataset = 'imagenet'
FLAGS.mode = 'train_and_eval'
FLAGS.data_dir = self.data_dir
......@@ -372,7 +373,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
# input skip_steps.
warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark(
super(KerasClassifierBenchmarkBase, self)._report_benchmark(
stats,
wall_time_sec,
total_batch_size=total_batch_size,
......@@ -599,8 +600,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='mirrored',
per_replica_batch_size=256,
gpu_thread_mode='gpu_private',
dataset_num_private_threads=48,
steps=310)
dataset_num_private_threads=48)
def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
"""Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
......@@ -636,6 +636,28 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_2x2_tpu_bf16_mlir(self):
"""Test Keras model with 2x2 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_2x2_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=8,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_4x4_tpu_bf16_mlir(self):
"""Test Keras model with 4x4 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_4x4_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=32,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_8x8_tpu_bf16(self):
"""Test Keras model with 8x8 TPU, bf16."""
self._setup()
......@@ -647,7 +669,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
per_replica_batch_size=64)
def fill_report_object(self, stats):
super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object(
super(KerasClassifierBenchmarkBase, self).fill_report_object(
stats,
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
......@@ -1086,7 +1108,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
log_steps=FLAGS.log_steps)
class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
class Resnet50KerasBenchmarkSynth(KerasClassifierBenchmarkBase):
"""Resnet50 synthetic benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
......@@ -1094,11 +1116,11 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu,
model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='synthetic', train_epochs=1, train_steps=110)
class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
class Resnet50KerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""Resnet50 real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
......@@ -1107,11 +1129,25 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu,
model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir)
class EfficientNetKerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""EfficientNet real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
data_dir = os.path.join(root_data_dir, 'imagenet')
def_flags = {}
def_flags['log_steps'] = 10
super(EfficientNetKerasBenchmarkReal, self).__init__(
model='efficientnet', output_dir=output_dir, default_flags=def_flags,
tpu=tpu, dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir)
class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
"""Resnet50 real data (stored in remote storage) benchmark tests."""
......
......@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
class CtlBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing."""
def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
**kwargs):
self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {}
super(CtlBenchmark, self).__init__(
output_dir=output_dir,
default_flags=self.default_flags,
flag_methods=self.flag_methods)
flag_methods=self.flag_methods,
**kwargs)
def _report_benchmark(self,
stats,
......@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None):
def __init__(self, output_dir=None, default_flags=None, **kwargs):
flag_methods = [common.define_keras_flags]
super(Resnet50CtlBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags)
default_flags=default_flags,
**kwargs)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
......@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.single_l2_loss_op = True
FLAGS.use_tf_function = True
FLAGS.enable_checkpoint_and_export = False
FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'
def benchmark_2x2_tpu_bf16(self):
self._setup()
self._set_df_common()
FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_bf16_mlir(self):
self._setup()
self._set_df_common()
FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16'
tf.config.experimental.enable_mlir_bridge()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
self._run_and_report_benchmark()
def benchmark_4x4_tpu_bf16(self):
......@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common()
FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
......@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common()
FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark()
......@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20
def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags)
output_dir=output_dir, default_flags=def_flags, **kwargs)
class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
......@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20
def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags)
output_dir=output_dir, default_flags=def_flags, **kwargs)
if __name__ == '__main__':
......
......@@ -44,11 +44,11 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi
# pylint: enable=line-too-long
class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
class BenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
"""Base class to hold methods common to test classes."""
def __init__(self, **kwargs):
super(DetectionBenchmarkBase, self).__init__(**kwargs)
super(BenchmarkBase, self).__init__(**kwargs)
self.timer_callback = None
def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap,
......@@ -99,7 +99,7 @@ class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
extras={'flags': flags_str})
class RetinanetBenchmarkBase(DetectionBenchmarkBase):
class DetectionBenchmarkBase(BenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, **kwargs):
......@@ -107,7 +107,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
self.eval_data_path = COCO_EVAL_DATA
self.eval_json_path = COCO_EVAL_JSON
self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
super(RetinanetBenchmarkBase, self).__init__(**kwargs)
super(DetectionBenchmarkBase, self).__init__(**kwargs)
def _run_detection_main(self):
"""Starts detection job."""
......@@ -118,7 +118,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
return detection.run()
class RetinanetAccuracy(RetinanetBenchmarkBase):
class DetectionAccuracy(DetectionBenchmarkBase):
"""Accuracy test for RetinaNet model.
Tests RetinaNet detection task model accuracy. The naming
......@@ -126,6 +126,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format.
"""
def __init__(self, model, **kwargs):
self.model = model
super(DetectionAccuracy, self).__init__(**kwargs)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
params,
......@@ -133,7 +137,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap=0.35,
do_eval=True,
warmup=1):
"""Starts RetinaNet accuracy benchmark test."""
"""Starts Detection accuracy benchmark test."""
FLAGS.params_override = json.dumps(params)
# Need timer callback to measure performance
self.timer_callback = keras_utils.TimeHistory(
......@@ -156,8 +160,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap, warmup)
def _setup(self):
super(RetinanetAccuracy, self)._setup()
FLAGS.model = 'retinanet'
super(DetectionAccuracy, self)._setup()
FLAGS.model = self.model
def _params(self):
return {
......@@ -195,22 +199,22 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
self._run_and_report_benchmark(params)
class RetinanetBenchmarkReal(RetinanetAccuracy):
"""Short benchmark performance tests for RetinaNet model.
class DetectionBenchmarkReal(DetectionAccuracy):
"""Short benchmark performance tests for a detection model.
Tests RetinaNet performance in different GPU configurations.
Tests detection performance in different accelerator configurations.
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format.
"""
def _setup(self):
super(RetinanetBenchmarkReal, self)._setup()
super(DetectionBenchmarkReal, self)._setup()
# Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1
@flagsaver.flagsaver
def benchmark_8_gpu_coco(self):
"""Run RetinaNet model accuracy test with 8 GPUs."""
"""Run detection model accuracy test with 8 GPUs."""
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
......@@ -230,7 +234,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU."""
"""Run detection model accuracy test with 1 GPU."""
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
......@@ -245,7 +249,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_xla_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled."""
"""Run detection model accuracy test with 1 GPU and XLA enabled."""
self._setup()
params = self._params()
params['architecture']['use_bfloat16'] = False
......@@ -261,7 +265,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
......@@ -271,6 +275,88 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self):
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self):
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self):
"""Run detection model with SpineNet backbone accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['architecture']['backbone'] = 'spinenet'
params['architecture']['multilevel_features'] = 'identity'
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
params['train']['checkpoint']['path'] = ''
FLAGS.model_dir = self._get_model_dir(
'real_benchmark_2x2_tpu_spinenet_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
class RetinanetBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Retinanet model."""
def __init__(self, **kwargs):
super(RetinanetBenchmarkReal, self).__init__(
model='retinanet',
**kwargs)
class MaskRCNNBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Mask RCNN model."""
def __init__(self, **kwargs):
super(MaskRCNNBenchmarkReal, self).__init__(
model='mask_rcnn',
**kwargs)
class ShapeMaskBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for ShapeMask model."""
def __init__(self, **kwargs):
super(ShapeMaskBenchmarkReal, self).__init__(
model='shapemask',
**kwargs)
if __name__ == '__main__':
tf.test.main()
......@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core
TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
GPU_DATA_DIR = os.getenv('TMPDIR')
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS
......@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
Code under test for the Transformer Keras models report the same data and
require the same FLAG setup.
"""
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None, tpu=None):
self._set_data_files(root_data_dir=root_data_dir)
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_files(self, root_data_dir=None, tpu_run=False):
"""Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
# Use remote storage for TPU, remote storage for GPU if defined, else
# use environment provided root_data_dir.
if tpu_run:
root_data_dir = TPU_DATA_DIR
elif GPU_DATA_DIR is not None:
root_data_dir = GPU_DATA_DIR
root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
self.vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
self.bleu_source = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.en')
self.bleu_ref = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de')
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_file_flags(self):
"""Sets the FLAGS for the data files."""
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
......@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 2048
FLAGS.train_steps = 1000
......@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096
FLAGS.train_steps = 100000
......@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000
......@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000
......@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Iterations are not epochs, an iteration is a number of steps between evals.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
......@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
......@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
not epochs, an iteration is a number of steps between evals.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
......@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
......@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
......@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
......@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu)
def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
def _set_df_common(self):
self._set_data_files(tpu_run=True)
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
FLAGS.distribution_strategy = 'tpu'
FLAGS.padded_decode = True
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 6144
FLAGS.enable_checkpointing = False
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.batch_size = 6144
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
"""Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
FLAGS.batch_size = 6144
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
......@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu(self):
"""Port of former GCP transformer_big model on 4x4."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
......@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu_mlir(self):
"""Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
......
......@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
"""Runs and reports the benchmark given the provided configuration."""
params = unet_training_lib.extract_params(FLAGS)
strategy = unet_training_lib.create_distribution_strategy(params)
if params.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
stats = {}
......
......@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
......@@ -104,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -128,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -185,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -204,12 +204,12 @@
"id": "9uFskufsR2LT"
},
"source": [
"You can get a pre-trained BERT encoder from TensorFlow Hub here:"
"You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -252,7 +252,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -267,7 +267,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -290,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -313,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -336,7 +336,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -376,7 +376,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -404,7 +404,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -446,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -469,7 +469,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -490,7 +490,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -514,7 +514,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -562,7 +562,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -587,7 +587,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -617,7 +617,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -661,7 +661,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -691,7 +691,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -737,7 +737,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -769,7 +769,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -793,7 +793,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -816,7 +816,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -845,7 +845,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -870,7 +870,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -908,7 +908,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -943,7 +943,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -986,7 +986,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1023,7 +1023,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1055,7 +1055,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1071,7 +1071,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1096,7 +1096,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1110,7 +1110,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1176,7 +1176,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1201,7 +1201,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1240,7 +1240,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1273,7 +1273,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1306,7 +1306,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1351,7 +1351,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1379,7 +1379,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1406,17 +1406,44 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
"id": "GDWrHm0BGpbX"
},
"outputs": [],
"source": [
"# Note: 350MB download.\n",
"import tensorflow_hub as hub\n",
"hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n",
"import tensorflow_hub as hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Y29meH0qGq_5"
},
"outputs": [],
"source": [
"hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
" trainable=True)\n",
"\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
]
......@@ -1433,7 +1460,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1466,7 +1493,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1491,7 +1518,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1504,7 +1531,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1545,7 +1572,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1569,7 +1596,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1592,7 +1619,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1617,7 +1644,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1643,7 +1670,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1661,7 +1688,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1688,7 +1715,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1714,7 +1741,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1733,7 +1760,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1761,7 +1788,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1795,7 +1822,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Bp8t2AI8i7uP"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "rxPj2Lsni9O4"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6xS-9i5DrRvO"
},
"source": [
"# Customizing a Transformer Encoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Mwb9uw1cDXsa"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iLrcV4IyrcGX"
},
"source": [
"## Learning objectives\n",
"\n",
"The [TensorFlow Models NLP library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling) is a collection of tools for building and training modern high performance natural language models.\n",
"\n",
"The [TransformEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) is the core of this library, and lots of new network architectures are proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YYxdyoWgsl8t"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fEJSFutUsn_h"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "thsKZDjhswhR"
},
"outputs": [],
"source": [
"!pip install -q tf-nightly\n",
"!pip install -q tf-models-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hpf7JPCVsqtv"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "my4dp-RMssQe"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from official.modeling import activations\n",
"from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vjDmVsFfs85n"
},
"source": [
"## Canonical BERT encoder\n",
"\n",
"Before learning how to customize the encoder, let's firstly create a canonical BERT enoder and use it to instantiate a `BertClassifier` for classification task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Oav8sbgstWc-"
},
"outputs": [],
"source": [
"cfg = {\n",
" \"vocab_size\": 100,\n",
" \"hidden_size\": 32,\n",
" \"num_layers\": 3,\n",
" \"num_attention_heads\": 4,\n",
" \"intermediate_size\": 64,\n",
" \"activation\": activations.gelu,\n",
" \"dropout_rate\": 0.1,\n",
" \"attention_dropout_rate\": 0.1,\n",
" \"sequence_length\": 16,\n",
" \"type_vocab_size\": 2,\n",
" \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
"}\n",
"bert_encoder = modeling.networks.TransformerEncoder(**cfg)\n",
"\n",
"def build_classifier(bert_encoder):\n",
" return modeling.models.BertClassifier(bert_encoder, num_classes=2)\n",
"\n",
"canonical_classifier_model = build_classifier(bert_encoder)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qe2UWI6_tsHo"
},
"source": [
"`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb). We skip the code that trains the model here.\n",
"\n",
"After training, we can apply the model to do prediction.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "csED2d-Yt5h6"
},
"outputs": [],
"source": [
"def predict(model):\n",
" batch_size = 3\n",
" np.random.seed(0)\n",
" word_ids = np.random.randint(\n",
" cfg[\"vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
" mask = np.random.randint(2, size=(batch_size, cfg[\"sequence_length\"]))\n",
" type_ids = np.random.randint(\n",
" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
" print(model([word_ids, mask, type_ids], training=False))\n",
"\n",
"predict(canonical_classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PzKStEK9t_Pb"
},
"source": [
"## Customize BERT encoder\n",
"\n",
"One BERT encoder consists of an embedding network and multiple transformer blocks, and each transformer block contains an attention layer and a feedforward layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rmwQfhj6fmKz"
},
"source": [
"We provide easy ways to customize each of those components via (1)\n",
"[EncoderScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) and (2) [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xsMgEVHAui11"
},
"source": [
"### Use EncoderScaffold\n",
"\n",
"`EncoderScaffold` allows users to provide a custom embedding subnetwork\n",
" (which will replace the standard embedding logic) and/or a custom hidden layer class (which will replace the `Transformer` instantiation in the encoder)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-JBabpa2AOz8"
},
"source": [
"#### Without Customization\n",
"\n",
"Without any customization, `EncoderScaffold` behaves the same the canonical `TransformerEncoder`.\n",
"\n",
"As shown in the following example, `EncoderScaffold` can load `TransformerEncoder`'s weights and output the same values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ktNzKuVByZQf"
},
"outputs": [],
"source": [
"default_hidden_cfg = dict(\n",
" num_attention_heads=cfg[\"num_attention_heads\"],\n",
" intermediate_size=cfg[\"intermediate_size\"],\n",
" intermediate_activation=activations.gelu,\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" attention_dropout_rate=cfg[\"attention_dropout_rate\"],\n",
" kernel_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
")\n",
"default_embedding_cfg = dict(\n",
" vocab_size=cfg[\"vocab_size\"],\n",
" type_vocab_size=cfg[\"type_vocab_size\"],\n",
" hidden_size=cfg[\"hidden_size\"],\n",
" seq_length=cfg[\"sequence_length\"],\n",
" initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" max_seq_length=cfg[\"sequence_length\"],\n",
")\n",
"default_kwargs = dict(\n",
" hidden_cfg=default_hidden_cfg,\n",
" embedding_cfg=default_embedding_cfg,\n",
" num_hidden_instances=cfg[\"num_layers\"],\n",
" pooled_output_dim=cfg[\"hidden_size\"],\n",
" return_all_layer_outputs=True,\n",
" pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
")\n",
"encoder_scaffold = modeling.networks.EncoderScaffold(**default_kwargs)\n",
"classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
"classifier_model_from_encoder_scaffold.set_weights(\n",
" canonical_classifier_model.get_weights())\n",
"predict(classifier_model_from_encoder_scaffold)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sMaUmLyIuwcs"
},
"source": [
"#### Customize Embedding\n",
"\n",
"Next, we show how to use a customized embedding network.\n",
"\n",
"We firstly build an embedding network that will replace the default network. This one will have 2 inputs (`mask` and `word_ids`) instead of 3, and won't use positional embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LTinnaG6vcsw"
},
"outputs": [],
"source": [
"word_ids = tf.keras.layers.Input(\n",
" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
"mask = tf.keras.layers.Input(\n",
" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
"embedding_layer = modeling.layers.OnDeviceEmbedding(\n",
" vocab_size=cfg['vocab_size'],\n",
" embedding_width=cfg['hidden_size'],\n",
" initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
" name=\"word_embeddings\")\n",
"word_embeddings = embedding_layer(word_ids)\n",
"attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])\n",
"new_embedding_network = tf.keras.Model([word_ids, mask],\n",
" [word_embeddings, attention_mask])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HN7_yu-6O3qI"
},
"source": [
"Inspecting `new_embedding_network`, we can see it takes two inputs:\n",
"`input_word_ids` and `input_mask`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fO9zKFE4OpHp"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9cOaGQHLv12W"
},
"source": [
"We then can build a new encoder using the above `new_embedding_network`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "mtFDMNf2vIl9"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use new embedding network.\n",
"kwargs['embedding_cls'] = new_embedding_network\n",
"kwargs['embedding_data'] = embedding_layer.embeddings\n",
"\n",
"encoder_with_customized_embedding = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_customized_embedding)\n",
"# ... Train the model ...\n",
"print(classifier_model.inputs)\n",
"\n",
"# Assert that there are only two inputs.\n",
"assert len(classifier_model.inputs) == 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Z73ZQDtmwg9K"
},
"source": [
"#### Customized Transformer\n",
"\n",
"User can also override the [hidden_cls](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py#L103) argument in `EncoderScaffold`'s constructor to employ a customized Transformer layer.\n",
"\n",
"See [ReZeroTransformer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/rezero_transformer.py) for how to implement a customized Transformer layer.\n",
"\n",
"Following is an example of using `ReZeroTransformer`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uAIarLZgw6pA"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use ReZeroTransformer.\n",
"kwargs['hidden_cls'] = modeling.layers.ReZeroTransformer\n",
"\n",
"encoder_with_rezero_transformer = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_rezero_transformer)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
"assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6PMHFdvnxvR0"
},
"source": [
"### Use [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)\n",
"\n",
"The above method of customizing `Transformer` requires rewriting the whole `Transformer` layer, while sometimes you may only want to customize either attention layer or feedforward block. In this case, [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py) can be used.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "D6FejlgwyAy_"
},
"source": [
"#### Customize Attention Layer\n",
"\n",
"User can also override the [attention_cls](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py#L45) argument in `TransformerScaffold`'s constructor to employ a customized Attention layer.\n",
"\n",
"See [TalkingHeadsAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py) for how to implement a customized `Attention` layer.\n",
"\n",
"Following is an example of using [TalkingHeadsAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "nFrSMrZuyNeQ"
},
"outputs": [],
"source": [
"# Use TalkingHeadsAttention\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['attention_cls'] = modeling.layers.TalkingHeadsAttention\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = modeling.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
"assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kuEJcTyByVvI"
},
"source": [
"#### Customize Feedforward Layer\n",
"\n",
"Similiarly, one could also customize the feedforward layer.\n",
"\n",
"See [GatedFeedforward](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py) for how to implement a customized feedforward layer.\n",
"\n",
"Following is an example of using [GatedFeedforward](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XAbKy_l4y_-i"
},
"outputs": [],
"source": [
"# Use TalkingHeadsAttention\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['feedforward_cls'] = modeling.layers.GatedFeedforward\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = modeling.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder_with_gated_feedforward = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_gated_feedforward)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `gate` from GatedFeedforward exists.\n",
"assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "a_8NWUhkzeAq"
},
"source": [
"### Build a new Encoder using building blocks from KerasBERT.\n",
"\n",
"Finally, you could also build a new encoder using building blocks in the modeling library.\n",
"\n",
"See [AlbertTransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_transformer_encoder.py) as an example:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xsiA3RzUzmUM"
},
"outputs": [],
"source": [
"albert_encoder = modeling.networks.AlbertTransformerEncoder(**cfg)\n",
"classifier_model = build_classifier(albert_encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MeidDfhlHKSO"
},
"source": [
"Inspecting the `albert_encoder`, we see it stacks the same `Transformer` layer multiple times."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Uv_juT22HERW"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Customizing a Transformer Encoder",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "80xnUmoI7fBX"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "8nvTnfs6Q692"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WmfcMK5P5C1G"
},
"source": [
"# Introduction to the TensorFlow Models NLP library"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cH-oJ8R6AHMK"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0H_EFIhq4-MJ"
},
"source": [
"## Learning objectives\n",
"\n",
"In this Colab notebook, you will learn how to build transformer-based models for common NLP tasks including pretraining, span labelling and classification using the building blocks from [NLP modeling library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2N97-dps_nUk"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "459ygAVl_rg0"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Y-qGkdh6_sZc"
},
"outputs": [],
"source": [
"!pip install -q tf-nightly\n",
"!pip install -q tf-models-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "e4huSSwyAG_5"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jqYXqtjBAJd9"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "djBQWjvy-60Y"
},
"source": [
"## BERT pretraining model\n",
"\n",
"BERT ([Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)) introduced the method of pre-training language representations on a large text corpus and then using that model for downstream NLP tasks.\n",
"\n",
"In this section, we will learn how to build a model to pretrain BERT on the masked language modeling task and next sentence prediction task. For simplicity, we only show the minimum example and use dummy data."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MKuHVlsCHmiq"
},
"source": [
"### Build a `BertPretrainer` model wrapping `TransformerEncoder`\n",
"\n",
"The [TransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/transformer_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n",
"\n",
"The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "EXkcXz-9BwB3"
},
"outputs": [],
"source": [
"# Build a small transformer network.\n",
"vocab_size = 100\n",
"sequence_length = 16\n",
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=16)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0NH5irV5KTMS"
},
"source": [
"Inspecting the encoder, we see it contains few embedding layers, stacked `Transformer` layers and are connected to three input layers:\n",
"\n",
"`input_word_ids`, `input_type_ids` and `input_mask`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lZNoZkBrIoff"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "o7eFOZXiIl-b"
},
"outputs": [],
"source": [
"# Create a BERT pretrainer with the created network.\n",
"num_token_predictions = 8\n",
"bert_pretrainer = modeling.models.BertPretrainer(\n",
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "d5h5HT7gNHx_"
},
"source": [
"Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `Classification` heads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2tcNfm03IBF7"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "F2oHrXGUIS0M"
},
"outputs": [],
"source": [
"# We can feed some dummy data to get masked language model and sentence output.\n",
"batch_size = 2\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"masked_lm_positions_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"\n",
"outputs = bert_pretrainer(\n",
" [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
"lm_output = outputs[\"masked_lm\"]\n",
"sentence_output = outputs[\"classification\"]\n",
"print(lm_output)\n",
"print(sentence_output)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "bnx3UCHniCS5"
},
"source": [
"### Compute loss\n",
"Next, we can use `lm_output` and `sentence_output` to compute `loss`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "k30H4Q86f52x"
},
"outputs": [],
"source": [
"masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n",
"\n",
"mlm_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=masked_lm_ids_data,\n",
" predictions=lm_output,\n",
" weights=masked_lm_weights_data)\n",
"sentence_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=next_sentence_labels_data,\n",
" predictions=sentence_output)\n",
"loss = mlm_loss + sentence_loss\n",
"print(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wrmSs8GjHxVw"
},
"source": [
"With the loss, you can optimize the model.\n",
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_pretraining.py) for the full example.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "k8cQVFvBCV4s"
},
"source": [
"## Span labeling model\n",
"\n",
"Span labeling is the task to assign labels to a span of the text, for example, label a span of text as the answer of a given question.\n",
"\n",
"In this section, we will learn how to build a span labeling model. Again, we use dummy data for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xrLLEWpfknUW"
},
"source": [
"### Build a BertSpanLabeler wrapping TransformerEncoder\n",
"\n",
"[BertSpanLabeler](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_span_labeler.py) implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
"\n",
"Note that `BertSpanLabeler` wraps a `TransformerEncoder`, the weights of which can be restored from the above pretraining model.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "B941M4iUCejO"
},
"outputs": [],
"source": [
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"bert_span_labeler = modeling.models.BertSpanLabeler(network)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QpB9pgj4PpMg"
},
"source": [
"Inspecting the `bert_span_labeler`, we see it wraps the encoder with additional `SpanLabeling` that outputs `start_position` and `end_postion`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RbqRNJCLJu4H"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fUf1vRxZJwio"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
"print(start_logits)\n",
"print(end_logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WqhgQaN1lt-G"
},
"source": [
"### Compute loss\n",
"With `start_logits` and `end_logits`, we can compute loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "waqs6azNl3Nn"
},
"outputs": [],
"source": [
"start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"\n",
"start_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" start_positions, start_logits, from_logits=True)\n",
"end_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" end_positions, end_logits, from_logits=True)\n",
"\n",
"total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
"print(total_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Zdf03YtZmd_d"
},
"source": [
"With the `loss`, you can optimize the model. Please see [run_squad.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_squad.py) for the full example."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0A1XnGSTChg9"
},
"source": [
"## Classification model\n",
"\n",
"In the last section, we show how to build a text classification model.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MSK8OpZgnQa9"
},
"source": [
"### Build a BertClassifier model wrapping TransformerEncoder\n",
"\n",
"[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a [CLS] token classification model containing a single classification head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cXXCsffkCphk"
},
"outputs": [],
"source": [
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"num_classes = 2\n",
"bert_classifier = modeling.models.BertClassifier(\n",
" network, num_classes=num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8tZKueKYP4bB"
},
"source": [
"Inspecting the `bert_classifier`, we see it wraps the `encoder` with additional `Classification` head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "snlutm9ZJgEZ"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "yyHPHsqBJkCz"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
"print(logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "w--a2mg4nzKm"
},
"source": [
"### Compute loss\n",
"\n",
"With `logits`, we can compute `loss`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9X0S1DoFn_5Q"
},
"outputs": [],
"source": [
"labels = np.random.randint(num_classes, size=(batch_size))\n",
"\n",
"loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n",
"print(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mzBqOylZo3og"
},
"source": [
"With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Introduction to the TensorFlow Models NLP library",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
......@@ -18,11 +18,11 @@ import abc
import functools
from typing import Any, Callable, Optional
from absl import logging
import six
import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
@six.add_metaclass(abc.ABCMeta)
......@@ -37,17 +37,29 @@ class Task(tf.Module):
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self, params: cfg.TaskConfig):
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
"""Task initialization.
Args:
params: cfg.TaskConfig instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
"""
self._task_config = params
self._logging_dir = logging_dir
@property
def task_config(self) -> cfg.TaskConfig:
return self._task_config
@property
def logging_dir(self) -> str:
return self._logging_dir
def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint found for the model.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir.
......@@ -55,11 +67,23 @@ class Task(tf.Module):
Args:
model: The keras.Model built or used by this task.
"""
pass
ckpt_dir_or_file = self.task_config.init_checkpoint
logging.info("Trying to load pretrained checkpoint from %s",
ckpt_dir_or_file)
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info("Finished loading pretrained checkpoint from %s",
ckpt_dir_or_file)
@abc.abstractmethod
def build_model(self) -> tf.keras.Model:
"""Creates the model architecture.
"""Creates model architecture.
Returns:
A model instance.
......@@ -107,6 +131,7 @@ class Task(tf.Module):
"""Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args:
params: hyperparams to create input pipelines.
......@@ -122,7 +147,7 @@ class Task(tf.Module):
Args:
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
......@@ -172,6 +197,8 @@ class Task(tf.Module):
metrics=None):
"""Does forward and backward.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
......@@ -217,7 +244,9 @@ class Task(tf.Module):
return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step.
"""Validation step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
......@@ -244,52 +273,24 @@ class Task(tf.Module):
return logs
def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step."""
return model(inputs, training=False)
_REGISTERED_TASK_CLS = {}
"""Performs the forward step.
# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
With distribution strategies, this method runs on devices.
This decorator supports registration of tasks as follows:
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
Returns:
Model outputs.
"""
return model(inputs, training=False)
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
def aggregate_logs(self, state, step_logs):
"""Optional aggregation over logs returned from a validation step."""
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
def reduce_aggregated_logs(self, aggregated_logs):
"""Optional reduce of aggregated logs over validation steps."""
return {}
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
# Copyright 2016 Google Inc. All Rights Reserved.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,30 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A binary to train Inception on the flowers data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Experiment factory methods."""
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
import tensorflow as tf
_REGISTERED_CONFIGS = {}
from inception import inception_train
from inception.flowers_data import FlowersData
FLAGS = tf.app.flags.FLAGS
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def main(_):
dataset = FlowersData(subset=FLAGS.subset)
assert dataset.data_files()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
inception_train.train(dataset)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
if __name__ == '__main__':
tf.app.run()
def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
return get_exp_config_creater(exp_name)()
......@@ -32,8 +32,9 @@ class InputReader:
dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
tf.data.Dataset]] = None,
transform_and_batch_fn: Optional[Callable[
[tf.data.Dataset, Optional[tf.distribute.InputContext]],
tf.data.Dataset]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None):
"""Initializes an InputReader instance.
......@@ -48,9 +49,12 @@ class InputReader:
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parse them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
dataset_transform_fn: An optional `callable` that takes a
`tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
executed after parser_fn.
transform_and_batch_fn: An optional `callable` that takes a
`tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
input, and returns a `tf.data.Dataset` object. It will be
executed after `parser_fn` to transform and batch the dataset; if None,
after `parser_fn` is executed, the dataset will be batched into
per-replica batch size.
postprocess_fn: A optional `callable` that processes batched tensors. It
will be executed after batching.
"""
......@@ -101,7 +105,7 @@ class InputReader:
self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn
self._parser_fn = parser_fn
self._dataset_transform_fn = dataset_transform_fn
self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn
def _read_sharded_files(
......@@ -171,6 +175,9 @@ class InputReader:
as_supervised=self._tfds_as_supervised,
decoders=decoders,
read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset
@property
......@@ -211,13 +218,13 @@ class InputReader:
dataset = maybe_map_fn(dataset, self._decoder_fn)
dataset = maybe_map_fn(dataset, self._parser_fn)
if self._dataset_transform_fn is not None:
dataset = self._dataset_transform_fn(dataset)
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context)
else:
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
dataset = maybe_map_fn(dataset, self._postprocess_fn)
return dataset.prefetch(tf.data.experimental.AUTOTUNE)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to register and access all registered tasks."""
from official.utils import registry
_REGISTERED_TASK_CLS = {}
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
def get_task(task_config, **kwargs):
"""Creates a Task (of suitable subclass type) from task_config."""
return get_task_cls(task_config.__class__)(task_config, **kwargs)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
......@@ -14,12 +14,6 @@
# ==============================================================================
"""Gaussian error linear unit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
......@@ -35,6 +29,4 @@ def gelu(x):
Returns:
`x` with the GELU activation applied.
"""
cdf = 0.5 * (1.0 + tf.tanh(
(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
return tf.keras.activations.gelu(x, approximate=True)
......@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict):
subconfig_type = Config
if k in cls.__annotations__:
# Directly Config subtype.
type_annotation = cls.__annotations__[k]
type_annotation = cls.__annotations__[k] # pytype: disable=invalid-annotation
if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)):
subconfig_type = cls.__annotations__[k]
subconfig_type = cls.__annotations__[k] # pytype: disable=invalid-annotation
else:
# Check if the field is a sequence of subtypes.
field_type = getattr(type_annotation, '__origin__', type(None))
......
......@@ -14,13 +14,13 @@
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Union
import dataclasses
from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config
from official.utils import registry
OptimizationConfig = optimization_config.OptimizationConfig
......@@ -111,6 +111,8 @@ class RuntimeConfig(base_config.Config):
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
allow_tpu_summary: Whether to allow summary happen inside the XLA program
runs on TPU through automatic outside compilation.
"""
distribution_strategy: str = "mirrored"
enable_xla: bool = False
......@@ -123,8 +125,8 @@ class RuntimeConfig(base_config.Config):
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[Union[str, float]] = None
mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
......@@ -172,25 +174,39 @@ class TrainerConfig(base_config.Config):
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_intervals: number of steps between checkpoints.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinetely.
checkpoints, if set to None, continuous eval will wait indefinitely.
This is only used continuous_train_and_eval and continuous_eval modes.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
# Orbit settings.
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
allow_tpu_summary: bool = False
# Trainer intervals.
steps_per_loop: int = 1000
summary_interval: int = 1000
checkpoint_interval: int = 1000
# Checkpoint manager.
max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None
# Train/Eval routines.
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 1000
@dataclasses.dataclass
class TaskConfig(base_config.Config):
network: base_config.Config = None
init_checkpoint: str = ""
model: base_config.Config = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
......@@ -198,24 +214,7 @@ class TaskConfig(base_config.Config):
@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
"""Top-level configuration."""
mode: str = "train" # train, eval, train_and_eval.
task: TaskConfig = TaskConfig()
trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig()
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 100
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment