Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
> :memo: A README.md template for releasing a paper code implementation to a GitHub repository. > :memo: A README.md template for releasing a paper code implementation to a GitHub repository.
> >
> * Template version: 1.0.2020.125 > * Template version: 1.0.2020.170
> * Please modify sections depending on needs. > * Please modify sections depending on needs.
# Model name, Paper title, or Project Name # Model name, Paper title, or Project Name
> :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN) > :memo: Add a badge for the ArXiv identifier of your paper (arXiv:YYMM.NNNNN)
[![Paper](http://img.shields.io/badge/paper-arXiv.YYMM.NNNNN-B3181B.svg)](https://arxiv.org/abs/...) [![Paper](http://img.shields.io/badge/Paper-arXiv.YYMM.NNNNN-B3181B?logo=arXiv)](https://arxiv.org/abs/...)
This repository is the official or unofficial implementation of the following paper. This repository is the official or unofficial implementation of the following paper.
...@@ -28,8 +28,8 @@ This repository is the official or unofficial implementation of the following pa ...@@ -28,8 +28,8 @@ This repository is the official or unofficial implementation of the following pa
> :memo: Provide maintainer information. > :memo: Provide maintainer information.
* Last name, First name ([@GitHub username](https://github.com/username)) * Full name ([@GitHub username](https://github.com/username))
* Last name, First name ([@GitHub username](https://github.com/username)) * Full name ([@GitHub username](https://github.com/username))
## Table of Contents ## Table of Contents
...@@ -37,8 +37,8 @@ This repository is the official or unofficial implementation of the following pa ...@@ -37,8 +37,8 @@ This repository is the official or unofficial implementation of the following pa
## Requirements ## Requirements
[![TensorFlow 2.1](https://img.shields.io/badge/tensorflow-2.1-brightgreen)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0) [![TensorFlow 2.1](https://img.shields.io/badge/TensorFlow-2.1-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0)
[![Python 3.6](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/downloads/release/python-360/) [![Python 3.6](https://img.shields.io/badge/Python-3.6-3776AB)](https://www.python.org/downloads/release/python-360/)
> :memo: Provide details of the software required. > :memo: Provide details of the software required.
> >
...@@ -54,6 +54,8 @@ pip install -r requirements.txt ...@@ -54,6 +54,8 @@ pip install -r requirements.txt
## Results ## Results
[![TensorFlow Hub](https://img.shields.io/badge/TF%20Hub-Models-FF6F00?logo=tensorflow)](https://tfhub.dev/...)
> :memo: Provide a table with results. (e.g., accuracy, latency) > :memo: Provide a table with results. (e.g., accuracy, latency)
> >
> * Provide links to the pre-trained models (checkpoint, SavedModel files). > * Provide links to the pre-trained models (checkpoint, SavedModel files).
...@@ -104,6 +106,8 @@ python3 ... ...@@ -104,6 +106,8 @@ python3 ...
## License ## License
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
> :memo: Place your license text in a file named LICENSE in the root of the repository. > :memo: Place your license text in a file named LICENSE in the root of the repository.
> >
> * Include information about your license. > * Include information about your license.
......
...@@ -2,28 +2,34 @@ ...@@ -2,28 +2,34 @@
# Welcome to the Model Garden for TensorFlow # Welcome to the Model Garden for TensorFlow
The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users can take full advantage of TensorFlow for their research and product development. The TensorFlow Model Garden is a repository with a number of different implementations of state-of-the-art (SOTA) models and modeling solutions for TensorFlow users. We aim to demonstrate the best practices for modeling so that TensorFlow users
can take full advantage of TensorFlow for their research and product development.
| Directory | Description | | Directory | Description |
|-----------|-------------| |-----------|-------------|
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read | | [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers | | [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 | | [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |
## [Announcements](../../wiki/Announcements) ## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
| Date | News | | Date | News |
|------|------| |------|------|
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released | July 10, 2020 | TensorFlow 2 meets the [Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) ([Blog](https://blog.tensorflow.org/2020/07/tensorflow-2-meets-object-detection-api.html)) |
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection | June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) |
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 | June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released |
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection |
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) | | March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
## Contributions ## Contributions
[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation) [![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
If you want to contribute, please review the [contribution guidelines](../../wiki/How-to-contribute). If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
## License ## License
......
...@@ -6,13 +6,12 @@ This repository provides a curated list of the GitHub repositories with machine ...@@ -6,13 +6,12 @@ This repository provides a curated list of the GitHub repositories with machine
**Note**: Contributing companies or individuals are responsible for maintaining their repositories. **Note**: Contributing companies or individuals are responsible for maintaining their repositories.
## Models / Implementations ## Computer Vision
### Computer Vision ### Image Recognition
#### Image Recognition | Model | Paper | Features | Maintainer |
| Model | Reference (Paper) | Features | Maintainer | |-------|-------|----------|------------|
|-------|-------------------|----------|------------|
| [DenseNet 169](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/densenet169) | [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) | • FP32 Inference | [Intel](https://github.com/IntelAI) | | [DenseNet 169](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/densenet169) | [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993) | • FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V3](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv3) | [Rethinking the Inception Architecture<br/>for Computer Vision](https://arxiv.org/pdf/1512.00567.pdf) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) | | [Inception V3](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv3) | [Rethinking the Inception Architecture<br/>for Computer Vision](https://arxiv.org/pdf/1512.00567.pdf) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [Inception V4](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv4) | [Inception-v4, Inception-ResNet and the Impact<br/>of Residual Connections on Learning](https://arxiv.org/pdf/1602.07261) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) | | [Inception V4](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/inceptionv4) | [Inception-v4, Inception-ResNet and the Impact<br/>of Residual Connections on Learning](https://arxiv.org/pdf/1602.07261) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
...@@ -21,12 +20,21 @@ This repository provides a curated list of the GitHub repositories with machine ...@@ -21,12 +20,21 @@ This repository provides a curated list of the GitHub repositories with machine
| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) | | [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) | | [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
#### Segmentation ### Object Detection
| Model | Reference (Paper) | &nbsp; &nbsp; &nbsp; Features &nbsp; &nbsp; &nbsp; | Maintainer |
|-------|-------------------|----------|------------| | Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Segmentation
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/MaskRCNN) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) | | [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/MaskRCNN) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
| [U-Net Medical Image Segmentation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/UNet_Medical) | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) | | [U-Net Medical Image Segmentation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow2/Segmentation/UNet_Medical) | [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) | • Automatic Mixed Precision<br/>• Multi-GPU training support with Horovod<br/>• TensorRT | [NVIDIA](https://github.com/NVIDIA) |
## Contributions ## Contributions
If you want to contribute, please review the [contribution guidelines](../../../wiki/How-to-contribute). If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
...@@ -17,11 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build. ...@@ -17,11 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
The team is actively developing new models. The team is actively developing new models.
In the near future, we will add: In the near future, we will add:
* State-of-the-art language understanding models: * State-of-the-art language understanding models.
More members in Transformer family * State-of-the-art image classification models.
* Start-of-the-art image classification models: * State-of-the-art objection detection and instance segmentation models.
EfficientNet, MnasNet, and variants
* A set of excellent objection detection models.
## Table of Contents ## Table of Contents
...@@ -43,6 +41,7 @@ In the near future, we will add: ...@@ -43,6 +41,7 @@ In the near future, we will add:
|-------|-------------------| |-------|-------------------|
| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) | | [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) | | [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
| [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
#### Object Detection and Segmentation #### Object Detection and Segmentation
...@@ -50,6 +49,8 @@ In the near future, we will add: ...@@ -50,6 +49,8 @@ In the near future, we will add:
|-------|-------------------| |-------|-------------------|
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) | | [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) | | [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing ### Natural Language Processing
......
...@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark(summary_path=summary_path, self._run_and_report_benchmark(summary_path=summary_path,
report_accuracy=True) report_accuracy=True)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 2x2 TPU for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 2x2 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden') @owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self): def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 4x4 TPU for 10000 steps.""" """Test bert pretraining with 4x4 TPU for 10000 steps."""
...@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark( self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False) summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 4x4 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden') @owner_utils.Owner('tf-model-garden')
def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self): def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 8x8 TPU for 10000 steps.""" """Test bert pretraining with 8x8 TPU for 10000 steps."""
......
...@@ -299,20 +299,21 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -299,20 +299,21 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
return os.path.join(self.output_dir, folder_name) return os.path.join(self.output_dir, folder_name)
class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Resnet50 (classifier_trainer) benchmarks.""" """Classifier Trainer benchmarks."""
def __init__(self, output_dir=None, default_flags=None, def __init__(self, model, output_dir=None, default_flags=None,
tpu=None, dataset_builder='records', train_epochs=1, tpu=None, dataset_builder='records', train_epochs=1,
train_steps=110, data_dir=None): train_steps=110, data_dir=None):
flag_methods = [classifier_trainer.define_classifier_flags] flag_methods = [classifier_trainer.define_classifier_flags]
self.model = model
self.dataset_builder = dataset_builder self.dataset_builder = dataset_builder
self.train_epochs = train_epochs self.train_epochs = train_epochs
self.train_steps = train_steps self.train_steps = train_steps
self.data_dir = data_dir self.data_dir = data_dir
super(Resnet50KerasClassifierBenchmarkBase, self).__init__( super(KerasClassifierBenchmarkBase, self).__init__(
output_dir=output_dir, output_dir=output_dir,
flag_methods=flag_methods, flag_methods=flag_methods,
default_flags=default_flags, default_flags=default_flags,
...@@ -337,7 +338,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -337,7 +338,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
dataset_num_private_threads: Optional[int] = None, dataset_num_private_threads: Optional[int] = None,
loss_scale: Optional[str] = None): loss_scale: Optional[str] = None):
"""Runs and reports the benchmark given the provided configuration.""" """Runs and reports the benchmark given the provided configuration."""
FLAGS.model_type = 'resnet' FLAGS.model_type = self.model
FLAGS.dataset = 'imagenet' FLAGS.dataset = 'imagenet'
FLAGS.mode = 'train_and_eval' FLAGS.mode = 'train_and_eval'
FLAGS.data_dir = self.data_dir FLAGS.data_dir = self.data_dir
...@@ -372,7 +373,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -372,7 +373,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
# input skip_steps. # input skip_steps.
warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark( super(KerasClassifierBenchmarkBase, self)._report_benchmark(
stats, stats,
wall_time_sec, wall_time_sec,
total_batch_size=total_batch_size, total_batch_size=total_batch_size,
...@@ -599,8 +600,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -599,8 +600,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='mirrored', distribution_strategy='mirrored',
per_replica_batch_size=256, per_replica_batch_size=256,
gpu_thread_mode='gpu_private', gpu_thread_mode='gpu_private',
dataset_num_private_threads=48, dataset_num_private_threads=48)
steps=310)
def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self): def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
"""Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16.""" """Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
...@@ -636,6 +636,28 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -636,6 +636,28 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='tpu', distribution_strategy='tpu',
per_replica_batch_size=128) per_replica_batch_size=128)
def benchmark_2x2_tpu_bf16_mlir(self):
"""Test Keras model with 2x2 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_2x2_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=8,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_4x4_tpu_bf16_mlir(self):
"""Test Keras model with 4x4 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_4x4_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=32,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_8x8_tpu_bf16(self): def benchmark_8x8_tpu_bf16(self):
"""Test Keras model with 8x8 TPU, bf16.""" """Test Keras model with 8x8 TPU, bf16."""
self._setup() self._setup()
...@@ -647,7 +669,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -647,7 +669,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
per_replica_batch_size=64) per_replica_batch_size=64)
def fill_report_object(self, stats): def fill_report_object(self, stats):
super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object( super(KerasClassifierBenchmarkBase, self).fill_report_object(
stats, stats,
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps) log_steps=FLAGS.log_steps)
...@@ -1086,7 +1108,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -1086,7 +1108,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
log_steps=FLAGS.log_steps) log_steps=FLAGS.log_steps)
class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase): class Resnet50KerasBenchmarkSynth(KerasClassifierBenchmarkBase):
"""Resnet50 synthetic benchmark tests.""" """Resnet50 synthetic benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
...@@ -1094,11 +1116,11 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase): ...@@ -1094,11 +1116,11 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkSynth, self).__init__( super(Resnet50KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu, model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='synthetic', train_epochs=1, train_steps=110) dataset_builder='synthetic', train_epochs=1, train_steps=110)
class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase): class Resnet50KerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""Resnet50 real data benchmark tests.""" """Resnet50 real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
...@@ -1107,11 +1129,25 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase): ...@@ -1107,11 +1129,25 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkReal, self).__init__( super(Resnet50KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu, model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='records', train_epochs=1, train_steps=110, dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir) data_dir=data_dir)
class EfficientNetKerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""EfficientNet real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
data_dir = os.path.join(root_data_dir, 'imagenet')
def_flags = {}
def_flags['log_steps'] = 10
super(EfficientNetKerasBenchmarkReal, self).__init__(
model='efficientnet', output_dir=output_dir, default_flags=def_flags,
tpu=tpu, dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir)
class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase): class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
"""Resnet50 real data (stored in remote storage) benchmark tests.""" """Resnet50 real data (stored in remote storage) benchmark tests."""
......
...@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS ...@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
class CtlBenchmark(PerfZeroBenchmark): class CtlBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing.""" """Base benchmark class with methods to simplify testing."""
def __init__(self, output_dir=None, default_flags=None, flag_methods=None): def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
**kwargs):
self.default_flags = default_flags or {} self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {} self.flag_methods = flag_methods or {}
super(CtlBenchmark, self).__init__( super(CtlBenchmark, self).__init__(
output_dir=output_dir, output_dir=output_dir,
default_flags=self.default_flags, default_flags=self.default_flags,
flag_methods=self.flag_methods) flag_methods=self.flag_methods,
**kwargs)
def _report_benchmark(self, def _report_benchmark(self,
stats, stats,
...@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark): ...@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
class Resnet50CtlBenchmarkBase(CtlBenchmark): class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks.""" """Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None): def __init__(self, output_dir=None, default_flags=None, **kwargs):
flag_methods = [common.define_keras_flags] flag_methods = [common.define_keras_flags]
super(Resnet50CtlBenchmarkBase, self).__init__( super(Resnet50CtlBenchmarkBase, self).__init__(
output_dir=output_dir, output_dir=output_dir,
flag_methods=flag_methods, flag_methods=flag_methods,
default_flags=default_flags) default_flags=default_flags,
**kwargs)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
...@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.single_l2_loss_op = True FLAGS.single_l2_loss_op = True
FLAGS.use_tf_function = True FLAGS.use_tf_function = True
FLAGS.enable_checkpoint_and_export = False FLAGS.enable_checkpoint_and_export = False
FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'
def benchmark_2x2_tpu_bf16(self): def benchmark_2x2_tpu_bf16(self):
self._setup() self._setup()
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 1024 FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_bf16_mlir(self):
self._setup()
self._set_df_common()
FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16'
tf.config.experimental.enable_mlir_bridge()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_4x4_tpu_bf16(self): def benchmark_4x4_tpu_bf16(self):
...@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
self._run_and_report_benchmark() self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler') @owner_utils.Owner('tf-graph-compiler')
...@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
tf.config.experimental.enable_mlir_bridge() tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase): ...@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20 def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkSynth, self).__init__( super(Resnet50CtlBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=def_flags, **kwargs)
class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase): class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
...@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase): ...@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet') def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20 def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkReal, self).__init__( super(Resnet50CtlBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=def_flags, **kwargs)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -44,11 +44,11 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi ...@@ -44,11 +44,11 @@ RESNET_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/retinanet/resnet50-checkpoi
# pylint: enable=line-too-long # pylint: enable=line-too-long
class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark): class BenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
"""Base class to hold methods common to test classes.""" """Base class to hold methods common to test classes."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DetectionBenchmarkBase, self).__init__(**kwargs) super(BenchmarkBase, self).__init__(**kwargs)
self.timer_callback = None self.timer_callback = None
def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap, def _report_benchmark(self, stats, start_time_sec, wall_time_sec, min_ap,
...@@ -99,7 +99,7 @@ class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark): ...@@ -99,7 +99,7 @@ class DetectionBenchmarkBase(perfzero_benchmark.PerfZeroBenchmark):
extras={'flags': flags_str}) extras={'flags': flags_str})
class RetinanetBenchmarkBase(DetectionBenchmarkBase): class DetectionBenchmarkBase(BenchmarkBase):
"""Base class to hold methods common to test classes in the module.""" """Base class to hold methods common to test classes in the module."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -107,7 +107,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase): ...@@ -107,7 +107,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
self.eval_data_path = COCO_EVAL_DATA self.eval_data_path = COCO_EVAL_DATA
self.eval_json_path = COCO_EVAL_JSON self.eval_json_path = COCO_EVAL_JSON
self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH self.resnet_checkpoint_path = RESNET_CHECKPOINT_PATH
super(RetinanetBenchmarkBase, self).__init__(**kwargs) super(DetectionBenchmarkBase, self).__init__(**kwargs)
def _run_detection_main(self): def _run_detection_main(self):
"""Starts detection job.""" """Starts detection job."""
...@@ -118,7 +118,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase): ...@@ -118,7 +118,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
return detection.run() return detection.run()
class RetinanetAccuracy(RetinanetBenchmarkBase): class DetectionAccuracy(DetectionBenchmarkBase):
"""Accuracy test for RetinaNet model. """Accuracy test for RetinaNet model.
Tests RetinaNet detection task model accuracy. The naming Tests RetinaNet detection task model accuracy. The naming
...@@ -126,6 +126,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -126,6 +126,10 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
`benchmark_(number of gpus)_gpu_(dataset type)` format. `benchmark_(number of gpus)_gpu_(dataset type)` format.
""" """
def __init__(self, model, **kwargs):
self.model = model
super(DetectionAccuracy, self).__init__(**kwargs)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
params, params,
...@@ -133,7 +137,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -133,7 +137,7 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap=0.35, max_ap=0.35,
do_eval=True, do_eval=True,
warmup=1): warmup=1):
"""Starts RetinaNet accuracy benchmark test.""" """Starts Detection accuracy benchmark test."""
FLAGS.params_override = json.dumps(params) FLAGS.params_override = json.dumps(params)
# Need timer callback to measure performance # Need timer callback to measure performance
self.timer_callback = keras_utils.TimeHistory( self.timer_callback = keras_utils.TimeHistory(
...@@ -156,8 +160,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -156,8 +160,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
max_ap, warmup) max_ap, warmup)
def _setup(self): def _setup(self):
super(RetinanetAccuracy, self)._setup() super(DetectionAccuracy, self)._setup()
FLAGS.model = 'retinanet' FLAGS.model = self.model
def _params(self): def _params(self):
return { return {
...@@ -195,22 +199,22 @@ class RetinanetAccuracy(RetinanetBenchmarkBase): ...@@ -195,22 +199,22 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
self._run_and_report_benchmark(params) self._run_and_report_benchmark(params)
class RetinanetBenchmarkReal(RetinanetAccuracy): class DetectionBenchmarkReal(DetectionAccuracy):
"""Short benchmark performance tests for RetinaNet model. """Short benchmark performance tests for a detection model.
Tests RetinaNet performance in different GPU configurations. Tests detection performance in different accelerator configurations.
The naming convention of below test cases follow The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format. `benchmark_(number of gpus)_gpu` format.
""" """
def _setup(self): def _setup(self):
super(RetinanetBenchmarkReal, self)._setup() super(DetectionBenchmarkReal, self)._setup()
# Use negative value to avoid saving checkpoints. # Use negative value to avoid saving checkpoints.
FLAGS.save_checkpoint_freq = -1 FLAGS.save_checkpoint_freq = -1
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_8_gpu_coco(self): def benchmark_8_gpu_coco(self):
"""Run RetinaNet model accuracy test with 8 GPUs.""" """Run detection model accuracy test with 8 GPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False params['architecture']['use_bfloat16'] = False
...@@ -230,7 +234,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -230,7 +234,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_1_gpu_coco(self): def benchmark_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU.""" """Run detection model accuracy test with 1 GPU."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False params['architecture']['use_bfloat16'] = False
...@@ -245,7 +249,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -245,7 +249,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_xla_1_gpu_coco(self): def benchmark_xla_1_gpu_coco(self):
"""Run RetinaNet model accuracy test with 1 GPU and XLA enabled.""" """Run detection model accuracy test with 1 GPU and XLA enabled."""
self._setup() self._setup()
params = self._params() params = self._params()
params['architecture']['use_bfloat16'] = False params['architecture']['use_bfloat16'] = False
...@@ -261,7 +265,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -261,7 +265,7 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_2x2_tpu_coco(self): def benchmark_2x2_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs.""" """Run detection model accuracy test with 4 TPUs."""
self._setup() self._setup()
params = self._params() params = self._params()
params['train']['batch_size'] = 64 params['train']['batch_size'] = 64
...@@ -271,6 +275,88 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -271,6 +275,88 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS.strategy_type = 'tpu' FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0) self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self):
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self):
"""Run detection model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self):
"""Run detection model with SpineNet backbone accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['architecture']['backbone'] = 'spinenet'
params['architecture']['multilevel_features'] = 'identity'
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
params['train']['checkpoint']['path'] = ''
FLAGS.model_dir = self._get_model_dir(
'real_benchmark_2x2_tpu_spinenet_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
class RetinanetBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Retinanet model."""
def __init__(self, **kwargs):
super(RetinanetBenchmarkReal, self).__init__(
model='retinanet',
**kwargs)
class MaskRCNNBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for Mask RCNN model."""
def __init__(self, **kwargs):
super(MaskRCNNBenchmarkReal, self).__init__(
model='mask_rcnn',
**kwargs)
class ShapeMaskBenchmarkReal(DetectionBenchmarkReal):
"""Short benchmark performance tests for ShapeMask model."""
def __init__(self, **kwargs):
super(ShapeMaskBenchmarkReal, self).__init__(
model='shapemask',
**kwargs)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc ...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
GPU_DATA_DIR = os.getenv('TMPDIR')
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official' TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014' EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
Code under test for the Transformer Keras models report the same data and Code under test for the Transformer Keras models report the same data and
require the same FLAG setup. require the same FLAG setup.
""" """
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None, def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None, tpu=None): flag_methods=None, tpu=None):
self._set_data_files(root_data_dir=root_data_dir)
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_files(self, root_data_dir=None, tpu_run=False):
"""Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
# Use remote storage for TPU, remote storage for GPU if defined, else
# use environment provided root_data_dir.
if tpu_run:
root_data_dir = TPU_DATA_DIR
elif GPU_DATA_DIR is not None:
root_data_dir = GPU_DATA_DIR
root_data_dir = root_data_dir if root_data_dir else '' root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir, self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME) TRANSFORMER_EN2DE_DATA_DIR_NAME)
self.vocab_file = os.path.join(root_data_dir, self.vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME, TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768') 'vocab.ende.32768')
self.bleu_source = os.path.join(root_data_dir, self.bleu_source = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.en') 'newstest2014.en')
self.bleu_ref = os.path.join(root_data_dir, self.bleu_ref = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de') 'newstest2014.de')
if default_flags is None: def _set_data_file_flags(self):
default_flags = {} """Sets the FLAGS for the data files."""
default_flags['data_dir'] = self.train_data_dir FLAGS.data_dir = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
super(TransformerBenchmark, self).__init__( FLAGS['bleu_source'].value = self.bleu_source
output_dir=output_dir, FLAGS['bleu_ref'].value = self.bleu_ref
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 2048 FLAGS.batch_size = 2048
FLAGS.train_steps = 1000 FLAGS.train_steps = 1000
...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Iterations are not epochs, an iteration is a number of steps between evals. Iterations are not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
not epochs, an iteration is a number of steps between evals. not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
root_data_dir=root_data_dir, batch_per_gpu=3072, root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu) tpu=tpu)
def benchmark_2x2_tpu(self): def _set_df_common(self):
"""Port of former snaggletooth transformer_big model on 2x2.""" self._set_data_files(tpu_run=True)
self._setup() FLAGS.data_dir = self.train_data_dir
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu') FLAGS.vocab_file = self.vocab_file
FLAGS.distribution_strategy = 'tpu'
FLAGS.padded_decode = True
FLAGS.train_steps = 300 FLAGS.train_steps = 300
FLAGS.log_steps = 150 FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150 FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True FLAGS.static_batch = True
FLAGS.use_ctl = True FLAGS.use_ctl = True
FLAGS.batch_size = 6144 FLAGS.enable_checkpointing = False
FLAGS.max_length = 64 FLAGS.max_length = 64
FLAGS.decode_batch_size = 32 FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97 FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.batch_size = 6144
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
"""Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
FLAGS.batch_size = 6144
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu(self): def benchmark_4x4_tpu(self):
"""Port of former GCP transformer_big model on 4x4.""" """Port of former GCP transformer_big model on 4x4."""
self._setup() self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu_mlir(self): def benchmark_4x4_tpu_mlir(self):
"""Run transformer_big model on 4x4 with the MLIR Bridge enabled.""" """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
self._setup() self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') self._set_df_common()
FLAGS.train_steps = 300 FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
tf.config.experimental.enable_mlir_bridge() tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
......
...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark): ...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
"""Runs and reports the benchmark given the provided configuration.""" """Runs and reports the benchmark given the provided configuration."""
params = unet_training_lib.extract_params(FLAGS) params = unet_training_lib.extract_params(FLAGS)
strategy = unet_training_lib.create_distribution_strategy(params) strategy = unet_training_lib.create_distribution_strategy(params)
if params.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy) tf.keras.mixed_precision.experimental.set_policy(policy)
stats = {} stats = {}
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"colab": {}, "colab": {},
...@@ -104,7 +104,7 @@ ...@@ -104,7 +104,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -185,7 +185,7 @@ ...@@ -185,7 +185,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -204,12 +204,12 @@ ...@@ -204,12 +204,12 @@
"id": "9uFskufsR2LT" "id": "9uFskufsR2LT"
}, },
"source": [ "source": [
"You can get a pre-trained BERT encoder from TensorFlow Hub here:" "You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -252,7 +252,7 @@ ...@@ -252,7 +252,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -267,7 +267,7 @@ ...@@ -267,7 +267,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -290,7 +290,7 @@ ...@@ -290,7 +290,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -313,7 +313,7 @@ ...@@ -313,7 +313,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -336,7 +336,7 @@ ...@@ -336,7 +336,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -376,7 +376,7 @@ ...@@ -376,7 +376,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -404,7 +404,7 @@ ...@@ -404,7 +404,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -446,7 +446,7 @@ ...@@ -446,7 +446,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -469,7 +469,7 @@ ...@@ -469,7 +469,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -490,7 +490,7 @@ ...@@ -490,7 +490,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -514,7 +514,7 @@ ...@@ -514,7 +514,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -562,7 +562,7 @@ ...@@ -562,7 +562,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -587,7 +587,7 @@ ...@@ -587,7 +587,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -617,7 +617,7 @@ ...@@ -617,7 +617,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -661,7 +661,7 @@ ...@@ -661,7 +661,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -691,7 +691,7 @@ ...@@ -691,7 +691,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -737,7 +737,7 @@ ...@@ -737,7 +737,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -769,7 +769,7 @@ ...@@ -769,7 +769,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -793,7 +793,7 @@ ...@@ -793,7 +793,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -816,7 +816,7 @@ ...@@ -816,7 +816,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -845,7 +845,7 @@ ...@@ -845,7 +845,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -870,7 +870,7 @@ ...@@ -870,7 +870,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -908,7 +908,7 @@ ...@@ -908,7 +908,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -943,7 +943,7 @@ ...@@ -943,7 +943,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -986,7 +986,7 @@ ...@@ -986,7 +986,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1023,7 +1023,7 @@ ...@@ -1023,7 +1023,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1055,7 +1055,7 @@ ...@@ -1055,7 +1055,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1071,7 +1071,7 @@ ...@@ -1071,7 +1071,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1096,7 +1096,7 @@ ...@@ -1096,7 +1096,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1110,7 +1110,7 @@ ...@@ -1110,7 +1110,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1176,7 +1176,7 @@ ...@@ -1176,7 +1176,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1201,7 +1201,7 @@ ...@@ -1201,7 +1201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1240,7 +1240,7 @@ ...@@ -1240,7 +1240,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1273,7 +1273,7 @@ ...@@ -1273,7 +1273,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1306,7 +1306,7 @@ ...@@ -1306,7 +1306,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1351,7 +1351,7 @@ ...@@ -1351,7 +1351,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1379,7 +1379,7 @@ ...@@ -1379,7 +1379,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1406,17 +1406,44 @@ ...@@ -1406,17 +1406,44 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "lo6479At4sP1" "id": "GDWrHm0BGpbX"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Note: 350MB download.\n", "# Note: 350MB download.\n",
"import tensorflow_hub as hub\n", "import tensorflow_hub as hub"
"hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n", ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Y29meH0qGq_5"
},
"outputs": [],
"source": [
"hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
" trainable=True)\n",
"\n", "\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")" "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
] ]
...@@ -1433,7 +1460,7 @@ ...@@ -1433,7 +1460,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1466,7 +1493,7 @@ ...@@ -1466,7 +1493,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1491,7 +1518,7 @@ ...@@ -1491,7 +1518,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1504,7 +1531,7 @@ ...@@ -1504,7 +1531,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1545,7 +1572,7 @@ ...@@ -1545,7 +1572,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1569,7 +1596,7 @@ ...@@ -1569,7 +1596,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1592,7 +1619,7 @@ ...@@ -1592,7 +1619,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1617,7 +1644,7 @@ ...@@ -1617,7 +1644,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1643,7 +1670,7 @@ ...@@ -1643,7 +1670,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1661,7 +1688,7 @@ ...@@ -1661,7 +1688,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1688,7 +1715,7 @@ ...@@ -1688,7 +1715,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1714,7 +1741,7 @@ ...@@ -1714,7 +1741,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1733,7 +1760,7 @@ ...@@ -1733,7 +1760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1761,7 +1788,7 @@ ...@@ -1761,7 +1788,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1795,7 +1822,7 @@ ...@@ -1795,7 +1822,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
......
This diff is collapsed.
This diff is collapsed.
...@@ -18,11 +18,11 @@ import abc ...@@ -18,11 +18,11 @@ import abc
import functools import functools
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from absl import logging
import six import six
import tensorflow as tf import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
...@@ -37,17 +37,29 @@ class Task(tf.Module): ...@@ -37,17 +37,29 @@ class Task(tf.Module):
# Special keys in train/validate step returned logs. # Special keys in train/validate step returned logs.
loss = "loss" loss = "loss"
def __init__(self, params: cfg.TaskConfig): def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
"""Task initialization.
Args:
params: cfg.TaskConfig instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
"""
self._task_config = params self._task_config = params
self._logging_dir = logging_dir
@property @property
def task_config(self) -> cfg.TaskConfig: def task_config(self) -> cfg.TaskConfig:
return self._task_config return self._task_config
@property
def logging_dir(self) -> str:
return self._logging_dir
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn. """A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint found for the model. This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir. checkpoint, saved under a directory other than the model_dir.
...@@ -55,11 +67,23 @@ class Task(tf.Module): ...@@ -55,11 +67,23 @@ class Task(tf.Module):
Args: Args:
model: The keras.Model built or used by this task. model: The keras.Model built or used by this task.
""" """
pass ckpt_dir_or_file = self.task_config.init_checkpoint
logging.info("Trying to load pretrained checkpoint from %s",
ckpt_dir_or_file)
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info("Finished loading pretrained checkpoint from %s",
ckpt_dir_or_file)
@abc.abstractmethod @abc.abstractmethod
def build_model(self) -> tf.keras.Model: def build_model(self) -> tf.keras.Model:
"""Creates the model architecture. """Creates model architecture.
Returns: Returns:
A model instance. A model instance.
...@@ -107,6 +131,7 @@ class Task(tf.Module): ...@@ -107,6 +131,7 @@ class Task(tf.Module):
"""Returns a dataset or a nested structure of dataset functions. """Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size. Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args: Args:
params: hyperparams to create input pipelines. params: hyperparams to create input pipelines.
...@@ -122,7 +147,7 @@ class Task(tf.Module): ...@@ -122,7 +147,7 @@ class Task(tf.Module):
Args: Args:
labels: optional label tensors. labels: optional label tensors.
model_outputs: a nested structure of output tensors. model_outputs: a nested structure of output tensors.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model. aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns: Returns:
The total loss tensor. The total loss tensor.
...@@ -172,6 +197,8 @@ class Task(tf.Module): ...@@ -172,6 +197,8 @@ class Task(tf.Module):
metrics=None): metrics=None):
"""Does forward and backward. """Does forward and backward.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
model: the model, forward pass definition. model: the model, forward pass definition.
...@@ -217,7 +244,9 @@ class Task(tf.Module): ...@@ -217,7 +244,9 @@ class Task(tf.Module):
return logs return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validation step.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
...@@ -244,52 +273,24 @@ class Task(tf.Module): ...@@ -244,52 +273,24 @@ class Task(tf.Module):
return logs return logs
def inference_step(self, inputs, model: tf.keras.Model): def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step.
return model(inputs, training=False)
_REGISTERED_TASK_CLS = {}
# TODO(b/158268740): Move these outside the base class file. With distribution strategies, this method runs on devices.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows: Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
``` Returns:
@dataclasses.dataclass Model outputs.
class MyTaskConfig(TaskConfig): """
# Add fields here. return model(inputs, training=False)
pass
@register_task_cls(MyTaskConfig) def aggregate_logs(self, state, step_logs):
class MyTask(Task): """Optional aggregation over logs returned from a validation step."""
# Inherits def __init__(self, task_config).
pass pass
my_task_config = MyTaskConfig() def reduce_aggregated_logs(self, aggregated_logs):
my_task = get_task(my_task_config) # Returns MyTask(my_task_config). """Optional reduce of aggregated logs over validation steps."""
``` return {}
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
# Copyright 2016 Google Inc. All Rights Reserved. # Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,30 +13,25 @@ ...@@ -12,30 +13,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A binary to train Inception on the flowers data set. """Experiment factory methods."""
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
import tensorflow as tf _REGISTERED_CONFIGS = {}
from inception import inception_train
from inception.flowers_data import FlowersData
FLAGS = tf.app.flags.FLAGS def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def main(_): def get_exp_config_creater(exp_name: str):
dataset = FlowersData(subset=FLAGS.subset) """Looks up ExperimentConfig factory methods."""
assert dataset.data_files() exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
if tf.gfile.Exists(FLAGS.train_dir): return exp_creater
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
inception_train.train(dataset)
if __name__ == '__main__': def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
tf.app.run() return get_exp_config_creater(exp_name)()
...@@ -32,8 +32,9 @@ class InputReader: ...@@ -32,8 +32,9 @@ class InputReader:
dataset_fn=tf.data.TFRecordDataset, dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None, decoder_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None, parser_fn: Optional[Callable[..., Any]] = None,
dataset_transform_fn: Optional[Callable[[tf.data.Dataset], transform_and_batch_fn: Optional[Callable[
tf.data.Dataset]] = None, [tf.data.Dataset, Optional[tf.distribute.InputContext]],
tf.data.Dataset]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None): postprocess_fn: Optional[Callable[..., Any]] = None):
"""Initializes an InputReader instance. """Initializes an InputReader instance.
...@@ -48,9 +49,12 @@ class InputReader: ...@@ -48,9 +49,12 @@ class InputReader:
parser_fn: An optional `callable` that takes the decoded raw tensors dict parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parse them into a dictionary of tensors that can be consumed by the and parse them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn. model. It will be executed after decoder_fn.
dataset_transform_fn: An optional `callable` that takes a transform_and_batch_fn: An optional `callable` that takes a
`tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
executed after parser_fn. input, and returns a `tf.data.Dataset` object. It will be
executed after `parser_fn` to transform and batch the dataset; if None,
after `parser_fn` is executed, the dataset will be batched into
per-replica batch size.
postprocess_fn: A optional `callable` that processes batched tensors. It postprocess_fn: A optional `callable` that processes batched tensors. It
will be executed after batching. will be executed after batching.
""" """
...@@ -101,7 +105,7 @@ class InputReader: ...@@ -101,7 +105,7 @@ class InputReader:
self._dataset_fn = dataset_fn self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn self._decoder_fn = decoder_fn
self._parser_fn = parser_fn self._parser_fn = parser_fn
self._dataset_transform_fn = dataset_transform_fn self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn self._postprocess_fn = postprocess_fn
def _read_sharded_files( def _read_sharded_files(
...@@ -171,6 +175,9 @@ class InputReader: ...@@ -171,6 +175,9 @@ class InputReader:
as_supervised=self._tfds_as_supervised, as_supervised=self._tfds_as_supervised,
decoders=decoders, decoders=decoders,
read_config=read_config) read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset return dataset
@property @property
...@@ -211,13 +218,13 @@ class InputReader: ...@@ -211,13 +218,13 @@ class InputReader:
dataset = maybe_map_fn(dataset, self._decoder_fn) dataset = maybe_map_fn(dataset, self._decoder_fn)
dataset = maybe_map_fn(dataset, self._parser_fn) dataset = maybe_map_fn(dataset, self._parser_fn)
if self._dataset_transform_fn is not None: if self._transform_and_batch_fn is not None:
dataset = self._dataset_transform_fn(dataset) dataset = self._transform_and_batch_fn(dataset, input_context)
else:
per_replica_batch_size = input_context.get_per_replica_batch_size( per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size self._global_batch_size) if input_context else self._global_batch_size
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
dataset = maybe_map_fn(dataset, self._postprocess_fn) dataset = maybe_map_fn(dataset, self._postprocess_fn)
return dataset.prefetch(tf.data.experimental.AUTOTUNE) return dataset.prefetch(tf.data.experimental.AUTOTUNE)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to register and access all registered tasks."""
from official.utils import registry
_REGISTERED_TASK_CLS = {}
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
def get_task(task_config, **kwargs):
"""Creates a Task (of suitable subclass type) from task_config."""
return get_task_cls(task_config.__class__)(task_config, **kwargs)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
...@@ -14,12 +14,6 @@ ...@@ -14,12 +14,6 @@
# ============================================================================== # ==============================================================================
"""Gaussian error linear unit.""" """Gaussian error linear unit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf import tensorflow as tf
...@@ -35,6 +29,4 @@ def gelu(x): ...@@ -35,6 +29,4 @@ def gelu(x):
Returns: Returns:
`x` with the GELU activation applied. `x` with the GELU activation applied.
""" """
cdf = 0.5 * (1.0 + tf.tanh( return tf.keras.activations.gelu(x, approximate=True)
(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
return x * cdf
...@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict): ...@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict):
subconfig_type = Config subconfig_type = Config
if k in cls.__annotations__: if k in cls.__annotations__:
# Directly Config subtype. # Directly Config subtype.
type_annotation = cls.__annotations__[k] type_annotation = cls.__annotations__[k] # pytype: disable=invalid-annotation
if (isinstance(type_annotation, type) and if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)): issubclass(type_annotation, Config)):
subconfig_type = cls.__annotations__[k] subconfig_type = cls.__annotations__[k] # pytype: disable=invalid-annotation
else: else:
# Check if the field is a sequence of subtypes. # Check if the field is a sequence of subtypes.
field_type = getattr(type_annotation, '__origin__', type(None)) field_type = getattr(type_annotation, '__origin__', type(None))
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common configuration settings.""" """Common configuration settings."""
from typing import Optional, Union from typing import Optional, Union
import dataclasses import dataclasses
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config from official.modeling.optimization.configs import optimization_config
from official.utils import registry
OptimizationConfig = optimization_config.OptimizationConfig OptimizationConfig = optimization_config.OptimizationConfig
...@@ -111,6 +111,8 @@ class RuntimeConfig(base_config.Config): ...@@ -111,6 +111,8 @@ class RuntimeConfig(base_config.Config):
run_eagerly: Whether or not to run the experiment eagerly. run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance. persistent mode for CuDNN batch norm kernel for improved GPU performance.
allow_tpu_summary: Whether to allow summary happen inside the XLA program
runs on TPU through automatic outside compilation.
""" """
distribution_strategy: str = "mirrored" distribution_strategy: str = "mirrored"
enable_xla: bool = False enable_xla: bool = False
...@@ -123,8 +125,8 @@ class RuntimeConfig(base_config.Config): ...@@ -123,8 +125,8 @@ class RuntimeConfig(base_config.Config):
task_index: int = -1 task_index: int = -1
all_reduce_alg: Optional[str] = None all_reduce_alg: Optional[str] = None
num_packs: int = 1 num_packs: int = 1
loss_scale: Optional[Union[str, float]] = None
mixed_precision_dtype: Optional[str] = None mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False batchnorm_spatial_persistent: bool = False
...@@ -172,25 +174,39 @@ class TrainerConfig(base_config.Config): ...@@ -172,25 +174,39 @@ class TrainerConfig(base_config.Config):
eval_tf_function: whether or not to use tf_function for eval. eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop. steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary. summary_interval: number of steps between each summary.
checkpoint_intervals: number of steps between checkpoints. checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep. max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinetely. checkpoints, if set to None, continuous eval will wait indefinitely.
This is only used continuous_train_and_eval and continuous_eval modes.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
""" """
optimizer_config: OptimizationConfig = OptimizationConfig() optimizer_config: OptimizationConfig = OptimizationConfig()
# Orbit settings.
train_tf_while_loop: bool = True train_tf_while_loop: bool = True
train_tf_function: bool = True train_tf_function: bool = True
eval_tf_function: bool = True eval_tf_function: bool = True
allow_tpu_summary: bool = False
# Trainer intervals.
steps_per_loop: int = 1000 steps_per_loop: int = 1000
summary_interval: int = 1000 summary_interval: int = 1000
checkpoint_interval: int = 1000 checkpoint_interval: int = 1000
# Checkpoint manager.
max_to_keep: int = 5 max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None continuous_eval_timeout: Optional[int] = None
# Train/Eval routines.
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 1000
@dataclasses.dataclass @dataclasses.dataclass
class TaskConfig(base_config.Config): class TaskConfig(base_config.Config):
network: base_config.Config = None init_checkpoint: str = ""
model: base_config.Config = None
train_data: DataConfig = DataConfig() train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig() validation_data: DataConfig = DataConfig()
...@@ -198,24 +214,7 @@ class TaskConfig(base_config.Config): ...@@ -198,24 +214,7 @@ class TaskConfig(base_config.Config):
@dataclasses.dataclass @dataclasses.dataclass
class ExperimentConfig(base_config.Config): class ExperimentConfig(base_config.Config):
"""Top-level configuration.""" """Top-level configuration."""
mode: str = "train" # train, eval, train_and_eval.
task: TaskConfig = TaskConfig() task: TaskConfig = TaskConfig()
trainer: TrainerConfig = TrainerConfig() trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig() runtime: RuntimeConfig = RuntimeConfig()
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 100
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config_creater(exp_name: str):
"""Looks up ExperimentConfig factory methods."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment