"examples/vscode:/vscode.git/clone" did not exist on "a9288b49c974278dcb6c929b07620018a3f5595e"
Commit 31ca3b97 authored by Kaushik Shivakumar

Resolve merge conflicts

parents 3e9d886d 7fcd7cba
...@@ -10,24 +10,21 @@ can take full advantage of TensorFlow for their research and product development
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |
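Orbit intentionally leaves the body of the training loop to the user. As a rough, hedged illustration of the kind of customized, `tf.distribute`-aware loop such a library is meant to support, the sketch below uses plain TensorFlow 2.x APIs rather than Orbit's own classes; the model and data are placeholders:

```python
import tensorflow as tf

# Placeholder model and data; a real run would pull these from Model Garden configs.
strategy = tf.distribute.MirroredStrategy()  # On TPU, TPUStrategy plays the same role.
GLOBAL_BATCH_SIZE = 32

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  optimizer = tf.keras.optimizers.SGD(0.01)
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([256, 4]),
     tf.random.uniform([256], maxval=10, dtype=tf.int32))).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(features, labels):
  with tf.GradientTape() as tape:
    logits = model(features, training=True)
    per_example_loss = loss_fn(labels, logits)
    loss = tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

for features, labels in dist_dataset:
  per_replica_loss = strategy.run(train_step, args=(features, labels))
  print(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None))
```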
## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
| Date | News |
|------|------|
| July 10, 2020 | TensorFlow 2 meets the [Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) ([Blog](https://blog.tensorflow.org/2020/07/tensorflow-2-meets-object-detection-api.html)) |
| June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) |
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released |
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection |
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
## [Milestones](https://github.com/tensorflow/models/milestones)
| Date | Milestone |
|------|-----------|
| July 8, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |
## Contributions
[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
......
...@@ -20,6 +20,14 @@ This repository provides a curated list of the GitHub repositories with machine
| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Object Detection
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Segmentation
| Model | Paper | Features | Maintainer |
......
...@@ -17,11 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
The team is actively developing new models.
In the near future, we will add:
* State-of-the-art language understanding models.
* State-of-the-art image classification models.
* State-of-the-art object detection and instance segmentation models.
## Table of Contents
...@@ -52,6 +50,7 @@ In the near future, we will add:
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing
......
...@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark(summary_path=summary_path, self._run_and_report_benchmark(summary_path=summary_path,
report_accuracy=True) report_accuracy=True)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 2x2 TPU for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 2x2 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden') @owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self): def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 4x4 TPU for 10000 steps.""" """Test bert pretraining with 4x4 TPU for 10000 steps."""
...@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark( self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False) summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 4x4 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden') @owner_utils.Owner('tf-model-garden')
def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self): def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 8x8 TPU for 10000 steps.""" """Test bert pretraining with 8x8 TPU for 10000 steps."""
......
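The `*_mlir` benchmark variants added above differ from their non-MLIR counterparts only in calling `tf.config.experimental.enable_mlir_bridge()` before the run. A minimal, hedged sketch of that toggle outside the benchmark harness (the `train_fn` callable and `run_benchmark` helper are stand-ins, not part of this code):

```python
import tensorflow as tf

def run_benchmark(train_fn, use_mlir_bridge=False):
  """Optionally opts in to the experimental MLIR-based TF/XLA bridge, then runs train_fn."""
  if use_mlir_bridge:
    tf.config.experimental.enable_mlir_bridge()
  return train_fn()

# The bridge setting is process-wide once enabled, which is why the benchmarks
# above define separate *_mlir methods rather than toggling it per run.
run_benchmark(lambda: print('baseline run'))
run_benchmark(lambda: print('mlir bridge run'), use_mlir_bridge=True)
```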
...@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS ...@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
class CtlBenchmark(PerfZeroBenchmark): class CtlBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing.""" """Base benchmark class with methods to simplify testing."""
def __init__(self, output_dir=None, default_flags=None, flag_methods=None): def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
**kwargs):
self.default_flags = default_flags or {} self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {} self.flag_methods = flag_methods or {}
super(CtlBenchmark, self).__init__( super(CtlBenchmark, self).__init__(
output_dir=output_dir, output_dir=output_dir,
default_flags=self.default_flags, default_flags=self.default_flags,
flag_methods=self.flag_methods) flag_methods=self.flag_methods,
**kwargs)
def _report_benchmark(self, def _report_benchmark(self,
stats, stats,
...@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark): ...@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
class Resnet50CtlBenchmarkBase(CtlBenchmark): class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks.""" """Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None): def __init__(self, output_dir=None, default_flags=None, **kwargs):
flag_methods = [common.define_keras_flags] flag_methods = [common.define_keras_flags]
super(Resnet50CtlBenchmarkBase, self).__init__( super(Resnet50CtlBenchmarkBase, self).__init__(
output_dir=output_dir, output_dir=output_dir,
flag_methods=flag_methods, flag_methods=flag_methods,
default_flags=default_flags) default_flags=default_flags,
**kwargs)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
...@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.single_l2_loss_op = True FLAGS.single_l2_loss_op = True
FLAGS.use_tf_function = True FLAGS.use_tf_function = True
FLAGS.enable_checkpoint_and_export = False FLAGS.enable_checkpoint_and_export = False
FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'
def benchmark_2x2_tpu_bf16(self): def benchmark_2x2_tpu_bf16(self):
self._setup() self._setup()
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 1024 FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_bf16_mlir(self):
self._setup()
self._set_df_common()
FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16'
tf.config.experimental.enable_mlir_bridge()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_4x4_tpu_bf16(self): def benchmark_4x4_tpu_bf16(self):
...@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
self._run_and_report_benchmark() self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler') @owner_utils.Owner('tf-graph-compiler')
...@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
tf.config.experimental.enable_mlir_bridge() tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase): ...@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20 def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkSynth, self).__init__( super(Resnet50CtlBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=def_flags, **kwargs)
class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase): class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
...@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase): ...@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet') def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20 def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkReal, self).__init__( super(Resnet50CtlBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=def_flags, **kwargs)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -271,6 +271,61 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -271,6 +271,61 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS.strategy_type = 'tpu' FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0) self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self):
"""Run RetinaNet model accuracy test on a 4x4 TPU."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test on a 2x2 TPU with the MLIR bridge enabled."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test on a 4x4 TPU with the MLIR bridge enabled."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self):
"""Run RetinaNet accuracy test with a SpineNet backbone on a 2x2 TPU."""
self._setup()
params = self._params()
params['architecture']['backbone'] = 'spinenet'
params['architecture']['multilevel_features'] = 'identity'
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
params['train']['checkpoint']['path'] = ''
FLAGS.model_dir = self._get_model_dir(
'real_benchmark_2x2_tpu_spinenet_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc ...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
GPU_DATA_DIR = os.getenv('TMPDIR')
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official' TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014' EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
Code under test for the Transformer Keras models report the same data and Code under test for the Transformer Keras models report the same data and
require the same FLAG setup. require the same FLAG setup.
""" """
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None, def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None, tpu=None): flag_methods=None, tpu=None):
self._set_data_files(root_data_dir=root_data_dir)
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_files(self, root_data_dir=None, tpu_run=False):
"""Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
# For TPU runs, use the remote TPU_DATA_DIR; for GPU runs, use GPU_DATA_DIR if it
# is set, otherwise fall back to the environment-provided root_data_dir.
if tpu_run:
root_data_dir = TPU_DATA_DIR
elif GPU_DATA_DIR is not None:
root_data_dir = GPU_DATA_DIR
root_data_dir = root_data_dir if root_data_dir else '' root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir, self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME) TRANSFORMER_EN2DE_DATA_DIR_NAME)
self.vocab_file = os.path.join(root_data_dir, self.vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME, TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768') 'vocab.ende.32768')
self.bleu_source = os.path.join(root_data_dir, self.bleu_source = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.en') 'newstest2014.en')
self.bleu_ref = os.path.join(root_data_dir, self.bleu_ref = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de') 'newstest2014.de')
if default_flags is None: def _set_data_file_flags(self):
default_flags = {} """Sets the FLAGS for the data files."""
default_flags['data_dir'] = self.train_data_dir FLAGS.data_dir = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
super(TransformerBenchmark, self).__init__( FLAGS['bleu_source'].value = self.bleu_source
output_dir=output_dir, FLAGS['bleu_ref'].value = self.bleu_ref
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 2048 FLAGS.batch_size = 2048
FLAGS.train_steps = 1000 FLAGS.train_steps = 1000
...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Iterations are not epochs, an iteration is a number of steps between evals. Iterations are not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
not epochs, an iteration is a number of steps between evals. not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
root_data_dir=root_data_dir, batch_per_gpu=3072, root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu) tpu=tpu)
def benchmark_2x2_tpu(self): def _set_df_common(self):
"""Port of former snaggletooth transformer_big model on 2x2.""" self._set_data_files(tpu_run=True)
self._setup() FLAGS.data_dir = self.train_data_dir
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu') FLAGS.vocab_file = self.vocab_file
FLAGS.distribution_strategy = 'tpu'
FLAGS.padded_decode = True
FLAGS.train_steps = 300 FLAGS.train_steps = 300
FLAGS.log_steps = 150 FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150 FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True FLAGS.static_batch = True
FLAGS.use_ctl = True FLAGS.use_ctl = True
FLAGS.batch_size = 6144 FLAGS.enable_checkpointing = False
FLAGS.max_length = 64 FLAGS.max_length = 64
FLAGS.decode_batch_size = 32 FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97 FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.batch_size = 6144
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
"""Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
FLAGS.batch_size = 6144
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu(self): def benchmark_4x4_tpu(self):
"""Port of former GCP transformer_big model on 4x4.""" """Port of former GCP transformer_big model on 4x4."""
self._setup() self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu_mlir(self): def benchmark_4x4_tpu_mlir(self):
"""Run transformer_big model on 4x4 with the MLIR Bridge enabled.""" """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
self._setup() self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') self._set_df_common()
FLAGS.train_steps = 300 FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
tf.config.experimental.enable_mlir_bridge() tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
......
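The `_set_data_file_flags` helper introduced above assigns `FLAGS['bleu_source'].value` directly, with a comment that this avoids the validation check. A small, hedged sketch of that `absl.flags` behavior; the flag name and validator below are invented for illustration:

```python
from absl import flags

flags.DEFINE_string('bleu_source', None, 'Path to the BLEU source file.')
flags.register_validator(
    'bleu_source',
    lambda path: path is None or path.endswith('.en'),
    message='bleu_source must point to a .en file')

FLAGS = flags.FLAGS
FLAGS(['prog'])  # Parse with defaults so the flags become usable.

# Attribute-style assignment runs registered validators and would raise here:
#   FLAGS.bleu_source = 'newstest2014.raw'   # absl.flags.IllegalFlagValueError
# Assigning through the Flag object's .value skips that check:
FLAGS['bleu_source'].value = 'newstest2014.raw'
print(FLAGS.bleu_source)
```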
...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark): ...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
"""Runs and reports the benchmark given the provided configuration.""" """Runs and reports the benchmark given the provided configuration."""
params = unet_training_lib.extract_params(FLAGS) params = unet_training_lib.extract_params(FLAGS)
strategy = unet_training_lib.create_distribution_strategy(params) strategy = unet_training_lib.create_distribution_strategy(params)
if params.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy) tf.keras.mixed_precision.experimental.set_policy(policy)
stats = {} stats = {}
......
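The UNet-3D benchmark change above derives the Keras mixed-precision policy from the configured dtype rather than a boolean `use_bfloat16` flag. A standalone, hedged sketch of the same selection logic, using the experimental mixed-precision API these benchmarks target (the helper name is illustrative):

```python
import tensorflow as tf

def set_mixed_precision_policy(input_dtype):
  """Sets a global Keras mixed-precision policy when inputs are float16 or bfloat16."""
  if input_dtype in ('float16', 'bfloat16'):
    policy = tf.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

set_mixed_precision_policy('bfloat16')
print(tf.keras.mixed_precision.experimental.global_policy().name)  # mixed_bfloat16
```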
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"colab": {}, "colab": {},
...@@ -104,7 +104,7 @@ ...@@ -104,7 +104,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -185,7 +185,7 @@ ...@@ -185,7 +185,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -204,12 +204,12 @@ ...@@ -204,12 +204,12 @@
"id": "9uFskufsR2LT" "id": "9uFskufsR2LT"
}, },
"source": [ "source": [
"You can get a pre-trained BERT encoder from TensorFlow Hub here:" "You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -252,7 +252,7 @@ ...@@ -252,7 +252,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -267,7 +267,7 @@ ...@@ -267,7 +267,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -290,7 +290,7 @@ ...@@ -290,7 +290,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -313,7 +313,7 @@ ...@@ -313,7 +313,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -336,7 +336,7 @@ ...@@ -336,7 +336,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -376,7 +376,7 @@ ...@@ -376,7 +376,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -404,7 +404,7 @@ ...@@ -404,7 +404,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -446,7 +446,7 @@ ...@@ -446,7 +446,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -469,7 +469,7 @@ ...@@ -469,7 +469,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -490,7 +490,7 @@ ...@@ -490,7 +490,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -514,7 +514,7 @@ ...@@ -514,7 +514,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -562,7 +562,7 @@ ...@@ -562,7 +562,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -587,7 +587,7 @@ ...@@ -587,7 +587,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -617,7 +617,7 @@ ...@@ -617,7 +617,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -661,7 +661,7 @@ ...@@ -661,7 +661,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -691,7 +691,7 @@ ...@@ -691,7 +691,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -737,7 +737,7 @@ ...@@ -737,7 +737,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -769,7 +769,7 @@ ...@@ -769,7 +769,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -793,7 +793,7 @@ ...@@ -793,7 +793,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -816,7 +816,7 @@ ...@@ -816,7 +816,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -845,7 +845,7 @@ ...@@ -845,7 +845,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -870,7 +870,7 @@ ...@@ -870,7 +870,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -908,7 +908,7 @@ ...@@ -908,7 +908,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -943,7 +943,7 @@ ...@@ -943,7 +943,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -986,7 +986,7 @@ ...@@ -986,7 +986,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1023,7 +1023,7 @@ ...@@ -1023,7 +1023,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1055,7 +1055,7 @@ ...@@ -1055,7 +1055,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1071,7 +1071,7 @@ ...@@ -1071,7 +1071,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1096,7 +1096,7 @@ ...@@ -1096,7 +1096,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1110,7 +1110,7 @@ ...@@ -1110,7 +1110,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1176,7 +1176,7 @@ ...@@ -1176,7 +1176,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1201,7 +1201,7 @@ ...@@ -1201,7 +1201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1240,7 +1240,7 @@ ...@@ -1240,7 +1240,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1273,7 +1273,7 @@ ...@@ -1273,7 +1273,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1306,7 +1306,7 @@ ...@@ -1306,7 +1306,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1351,7 +1351,7 @@ ...@@ -1351,7 +1351,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1379,7 +1379,7 @@ ...@@ -1379,7 +1379,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1406,17 +1406,44 @@ ...@@ -1406,17 +1406,44 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "lo6479At4sP1" "id": "GDWrHm0BGpbX"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Note: 350MB download.\n", "# Note: 350MB download.\n",
"import tensorflow_hub as hub\n", "import tensorflow_hub as hub"
"hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n", ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Y29meH0qGq_5"
},
"outputs": [],
"source": [
"hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
" trainable=True)\n",
"\n", "\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")" "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
] ]
...@@ -1433,7 +1460,7 @@ ...@@ -1433,7 +1460,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1466,7 +1493,7 @@ ...@@ -1466,7 +1493,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1491,7 +1518,7 @@ ...@@ -1491,7 +1518,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1504,7 +1531,7 @@ ...@@ -1504,7 +1531,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1545,7 +1572,7 @@ ...@@ -1545,7 +1572,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1569,7 +1596,7 @@ ...@@ -1569,7 +1596,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1592,7 +1619,7 @@ ...@@ -1592,7 +1619,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1617,7 +1644,7 @@ ...@@ -1617,7 +1644,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1643,7 +1670,7 @@ ...@@ -1643,7 +1670,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1661,7 +1688,7 @@ ...@@ -1661,7 +1688,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1688,7 +1715,7 @@ ...@@ -1688,7 +1715,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1714,7 +1741,7 @@ ...@@ -1714,7 +1741,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1733,7 +1760,7 @@ ...@@ -1733,7 +1760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1761,7 +1788,7 @@ ...@@ -1761,7 +1788,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1795,7 +1822,7 @@ ...@@ -1795,7 +1822,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Bp8t2AI8i7uP"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "rxPj2Lsni9O4"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6xS-9i5DrRvO"
},
"source": [
"# Customizing a Transformer Encoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Mwb9uw1cDXsa"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iLrcV4IyrcGX"
},
"source": [
"## Learning objectives\n",
"\n",
"The [TensorFlow Models NLP library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling) is a collection of tools for building and training modern, high-performance natural language models.\n",
"\n",
"The [TransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) is the core of this library, and many new network architectures have been proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YYxdyoWgsl8t"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fEJSFutUsn_h"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "thsKZDjhswhR"
},
"outputs": [],
"source": [
"!pip install -q tf-nightly\n",
"!pip install -q tf-models-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hpf7JPCVsqtv"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "my4dp-RMssQe"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from official.modeling import activations\n",
"from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vjDmVsFfs85n"
},
"source": [
"## Canonical BERT encoder\n",
"\n",
"Before learning how to customize the encoder, let's firstly create a canonical BERT enoder and use it to instantiate a `BertClassifier` for classification task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Oav8sbgstWc-"
},
"outputs": [],
"source": [
"cfg = {\n",
" \"vocab_size\": 100,\n",
" \"hidden_size\": 32,\n",
" \"num_layers\": 3,\n",
" \"num_attention_heads\": 4,\n",
" \"intermediate_size\": 64,\n",
" \"activation\": activations.gelu,\n",
" \"dropout_rate\": 0.1,\n",
" \"attention_dropout_rate\": 0.1,\n",
" \"sequence_length\": 16,\n",
" \"type_vocab_size\": 2,\n",
" \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
"}\n",
"bert_encoder = modeling.networks.TransformerEncoder(**cfg)\n",
"\n",
"def build_classifier(bert_encoder):\n",
" return modeling.models.BertClassifier(bert_encoder, num_classes=2)\n",
"\n",
"canonical_classifier_model = build_classifier(bert_encoder)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qe2UWI6_tsHo"
},
"source": [
"`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb). We skip the code that trains the model here.\n",
"\n",
"After training, we can apply the model to do prediction.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "csED2d-Yt5h6"
},
"outputs": [],
"source": [
"def predict(model):\n",
" batch_size = 3\n",
" np.random.seed(0)\n",
" word_ids = np.random.randint(\n",
" cfg[\"vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
" mask = np.random.randint(2, size=(batch_size, cfg[\"sequence_length\"]))\n",
" type_ids = np.random.randint(\n",
" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
" print(model([word_ids, mask, type_ids], training=False))\n",
"\n",
"predict(canonical_classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PzKStEK9t_Pb"
},
"source": [
"## Customize BERT encoder\n",
"\n",
"One BERT encoder consists of an embedding network and multiple transformer blocks, and each transformer block contains an attention layer and a feedforward layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rmwQfhj6fmKz"
},
"source": [
"We provide easy ways to customize each of those components via (1)\n",
"[EncoderScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) and (2) [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xsMgEVHAui11"
},
"source": [
"### Use EncoderScaffold\n",
"\n",
"`EncoderScaffold` allows users to provide a custom embedding subnetwork\n",
" (which will replace the standard embedding logic) and/or a custom hidden layer class (which will replace the `Transformer` instantiation in the encoder)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-JBabpa2AOz8"
},
"source": [
"#### Without Customization\n",
"\n",
"Without any customization, `EncoderScaffold` behaves the same the canonical `TransformerEncoder`.\n",
"\n",
"As shown in the following example, `EncoderScaffold` can load `TransformerEncoder`'s weights and output the same values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ktNzKuVByZQf"
},
"outputs": [],
"source": [
"default_hidden_cfg = dict(\n",
" num_attention_heads=cfg[\"num_attention_heads\"],\n",
" intermediate_size=cfg[\"intermediate_size\"],\n",
" intermediate_activation=activations.gelu,\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" attention_dropout_rate=cfg[\"attention_dropout_rate\"],\n",
" kernel_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
")\n",
"default_embedding_cfg = dict(\n",
" vocab_size=cfg[\"vocab_size\"],\n",
" type_vocab_size=cfg[\"type_vocab_size\"],\n",
" hidden_size=cfg[\"hidden_size\"],\n",
" seq_length=cfg[\"sequence_length\"],\n",
" initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" max_seq_length=cfg[\"sequence_length\"],\n",
")\n",
"default_kwargs = dict(\n",
" hidden_cfg=default_hidden_cfg,\n",
" embedding_cfg=default_embedding_cfg,\n",
" num_hidden_instances=cfg[\"num_layers\"],\n",
" pooled_output_dim=cfg[\"hidden_size\"],\n",
" return_all_layer_outputs=True,\n",
" pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
")\n",
"encoder_scaffold = modeling.networks.EncoderScaffold(**default_kwargs)\n",
"classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
"classifier_model_from_encoder_scaffold.set_weights(\n",
" canonical_classifier_model.get_weights())\n",
"predict(classifier_model_from_encoder_scaffold)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sMaUmLyIuwcs"
},
"source": [
"#### Customize Embedding\n",
"\n",
"Next, we show how to use a customized embedding network.\n",
"\n",
"We firstly build an embedding network that will replace the default network. This one will have 2 inputs (`mask` and `word_ids`) instead of 3, and won't use positional embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LTinnaG6vcsw"
},
"outputs": [],
"source": [
"word_ids = tf.keras.layers.Input(\n",
" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
"mask = tf.keras.layers.Input(\n",
" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
"embedding_layer = modeling.layers.OnDeviceEmbedding(\n",
" vocab_size=cfg['vocab_size'],\n",
" embedding_width=cfg['hidden_size'],\n",
" initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
" name=\"word_embeddings\")\n",
"word_embeddings = embedding_layer(word_ids)\n",
"attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])\n",
"new_embedding_network = tf.keras.Model([word_ids, mask],\n",
" [word_embeddings, attention_mask])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HN7_yu-6O3qI"
},
"source": [
"Inspecting `new_embedding_network`, we can see it takes two inputs:\n",
"`input_word_ids` and `input_mask`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fO9zKFE4OpHp"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9cOaGQHLv12W"
},
"source": [
"We then can build a new encoder using the above `new_embedding_network`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "mtFDMNf2vIl9"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use new embedding network.\n",
"kwargs['embedding_cls'] = new_embedding_network\n",
"kwargs['embedding_data'] = embedding_layer.embeddings\n",
"\n",
"encoder_with_customized_embedding = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_customized_embedding)\n",
"# ... Train the model ...\n",
"print(classifier_model.inputs)\n",
"\n",
"# Assert that there are only two inputs.\n",
"assert len(classifier_model.inputs) == 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Z73ZQDtmwg9K"
},
"source": [
"#### Customized Transformer\n",
"\n",
"User can also override the [hidden_cls](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py#L103) argument in `EncoderScaffold`'s constructor to employ a customized Transformer layer.\n",
"\n",
"See [ReZeroTransformer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/rezero_transformer.py) for how to implement a customized Transformer layer.\n",
"\n",
"Following is an example of using `ReZeroTransformer`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uAIarLZgw6pA"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use ReZeroTransformer.\n",
"kwargs['hidden_cls'] = modeling.layers.ReZeroTransformer\n",
"\n",
"encoder_with_rezero_transformer = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_rezero_transformer)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
"assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6PMHFdvnxvR0"
},
"source": [
"### Use [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)\n",
"\n",
"The above method of customizing `Transformer` requires rewriting the whole `Transformer` layer, while sometimes you may only want to customize either attention layer or feedforward block. In this case, [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py) can be used.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "D6FejlgwyAy_"
},
"source": [
"#### Customize Attention Layer\n",
"\n",
"User can also override the [attention_cls](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py#L45) argument in `TransformerScaffold`'s constructor to employ a customized Attention layer.\n",
"\n",
"See [TalkingHeadsAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py) for how to implement a customized `Attention` layer.\n",
"\n",
"Following is an example of using [TalkingHeadsAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "nFrSMrZuyNeQ"
},
"outputs": [],
"source": [
"# Use TalkingHeadsAttention\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['attention_cls'] = modeling.layers.TalkingHeadsAttention\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = modeling.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
"assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kuEJcTyByVvI"
},
"source": [
"#### Customize Feedforward Layer\n",
"\n",
"Similiarly, one could also customize the feedforward layer.\n",
"\n",
"See [GatedFeedforward](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py) for how to implement a customized feedforward layer.\n",
"\n",
"Following is an example of using [GatedFeedforward](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XAbKy_l4y_-i"
},
"outputs": [],
"source": [
"# Use TalkingHeadsAttention\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['feedforward_cls'] = modeling.layers.GatedFeedforward\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = modeling.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder_with_gated_feedforward = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_gated_feedforward)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `gate` from GatedFeedforward exists.\n",
"assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "a_8NWUhkzeAq"
},
"source": [
"### Build a new Encoder using building blocks from KerasBERT.\n",
"\n",
"Finally, you could also build a new encoder using building blocks in the modeling library.\n",
"\n",
"See [AlbertTransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_transformer_encoder.py) as an example:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xsiA3RzUzmUM"
},
"outputs": [],
"source": [
"albert_encoder = modeling.networks.AlbertTransformerEncoder(**cfg)\n",
"classifier_model = build_classifier(albert_encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MeidDfhlHKSO"
},
"source": [
"Inspecting the `albert_encoder`, we see it stacks the same `Transformer` layer multiple times."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Uv_juT22HERW"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Customizing a Transformer Encoder",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
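The notebook above skips the code that trains `canonical_classifier_model`. The following is a minimal sketch of what that step could look like on dummy data; it assumes the notebook's `cfg` dictionary and `canonical_classifier_model` are still in scope, and the optimizer and hyperparameters are illustrative placeholders rather than the settings used in the BERT fine-tuning colab.

```python
# Minimal training sketch for the classifier built in the notebook above.
# Assumes `cfg` and `canonical_classifier_model` from that notebook are in scope;
# the data is random dummy data and the hyperparameters are placeholders.
import numpy as np
import tensorflow as tf

batch_size = 8
word_ids = np.random.randint(cfg["vocab_size"], size=(batch_size, cfg["sequence_length"]))
mask = np.random.randint(2, size=(batch_size, cfg["sequence_length"]))
type_ids = np.random.randint(cfg["type_vocab_size"], size=(batch_size, cfg["sequence_length"]))
labels = np.random.randint(2, size=(batch_size,))

canonical_classifier_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])
canonical_classifier_model.fit(
    x=[word_ids, mask, type_ids], y=labels, batch_size=batch_size, epochs=1)
```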
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "80xnUmoI7fBX"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "8nvTnfs6Q692"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WmfcMK5P5C1G"
},
"source": [
"# Introduction to the TensorFlow Models NLP library"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cH-oJ8R6AHMK"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0H_EFIhq4-MJ"
},
"source": [
"## Learning objectives\n",
"\n",
"In this Colab notebook, you will learn how to build transformer-based models for common NLP tasks including pretraining, span labelling and classification using the building blocks from [NLP modeling library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2N97-dps_nUk"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "459ygAVl_rg0"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Y-qGkdh6_sZc"
},
"outputs": [],
"source": [
"!pip install -q tf-nightly\n",
"!pip install -q tf-models-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "e4huSSwyAG_5"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jqYXqtjBAJd9"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "djBQWjvy-60Y"
},
"source": [
"## BERT pretraining model\n",
"\n",
"BERT ([Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)) introduced the method of pre-training language representations on a large text corpus and then using that model for downstream NLP tasks.\n",
"\n",
"In this section, we will learn how to build a model to pretrain BERT on the masked language modeling task and next sentence prediction task. For simplicity, we only show the minimum example and use dummy data."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MKuHVlsCHmiq"
},
"source": [
"### Build a `BertPretrainer` model wrapping `TransformerEncoder`\n",
"\n",
"The [TransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/transformer_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n",
"\n",
"The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "EXkcXz-9BwB3"
},
"outputs": [],
"source": [
"# Build a small transformer network.\n",
"vocab_size = 100\n",
"sequence_length = 16\n",
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=16)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0NH5irV5KTMS"
},
"source": [
"Inspecting the encoder, we see it contains few embedding layers, stacked `Transformer` layers and are connected to three input layers:\n",
"\n",
"`input_word_ids`, `input_type_ids` and `input_mask`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lZNoZkBrIoff"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "o7eFOZXiIl-b"
},
"outputs": [],
"source": [
"# Create a BERT pretrainer with the created network.\n",
"num_token_predictions = 8\n",
"bert_pretrainer = modeling.models.BertPretrainer(\n",
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "d5h5HT7gNHx_"
},
"source": [
"Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `Classification` heads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2tcNfm03IBF7"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "F2oHrXGUIS0M"
},
"outputs": [],
"source": [
"# We can feed some dummy data to get masked language model and sentence output.\n",
"batch_size = 2\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"masked_lm_positions_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"\n",
"outputs = bert_pretrainer(\n",
" [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
"lm_output = outputs[\"masked_lm\"]\n",
"sentence_output = outputs[\"classification\"]\n",
"print(lm_output)\n",
"print(sentence_output)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "bnx3UCHniCS5"
},
"source": [
"### Compute loss\n",
"Next, we can use `lm_output` and `sentence_output` to compute `loss`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "k30H4Q86f52x"
},
"outputs": [],
"source": [
"masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n",
"\n",
"mlm_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=masked_lm_ids_data,\n",
" predictions=lm_output,\n",
" weights=masked_lm_weights_data)\n",
"sentence_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=next_sentence_labels_data,\n",
" predictions=sentence_output)\n",
"loss = mlm_loss + sentence_loss\n",
"print(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wrmSs8GjHxVw"
},
"source": [
"With the loss, you can optimize the model.\n",
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_pretraining.py) for the full example.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "k8cQVFvBCV4s"
},
"source": [
"## Span labeling model\n",
"\n",
"Span labeling is the task to assign labels to a span of the text, for example, label a span of text as the answer of a given question.\n",
"\n",
"In this section, we will learn how to build a span labeling model. Again, we use dummy data for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xrLLEWpfknUW"
},
"source": [
"### Build a BertSpanLabeler wrapping TransformerEncoder\n",
"\n",
"[BertSpanLabeler](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_span_labeler.py) implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
"\n",
"Note that `BertSpanLabeler` wraps a `TransformerEncoder`, the weights of which can be restored from the above pretraining model.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "B941M4iUCejO"
},
"outputs": [],
"source": [
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"bert_span_labeler = modeling.models.BertSpanLabeler(network)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QpB9pgj4PpMg"
},
"source": [
"Inspecting the `bert_span_labeler`, we see it wraps the encoder with additional `SpanLabeling` that outputs `start_position` and `end_postion`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RbqRNJCLJu4H"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fUf1vRxZJwio"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
"print(start_logits)\n",
"print(end_logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WqhgQaN1lt-G"
},
"source": [
"### Compute loss\n",
"With `start_logits` and `end_logits`, we can compute loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "waqs6azNl3Nn"
},
"outputs": [],
"source": [
"start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"\n",
"start_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" start_positions, start_logits, from_logits=True)\n",
"end_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" end_positions, end_logits, from_logits=True)\n",
"\n",
"total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
"print(total_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Zdf03YtZmd_d"
},
"source": [
"With the `loss`, you can optimize the model. Please see [run_squad.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_squad.py) for the full example."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0A1XnGSTChg9"
},
"source": [
"## Classification model\n",
"\n",
"In the last section, we show how to build a text classification model.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MSK8OpZgnQa9"
},
"source": [
"### Build a BertClassifier model wrapping TransformerEncoder\n",
"\n",
"[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a simple token classification model containing a single classification head using the `TokenClassification` network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cXXCsffkCphk"
},
"outputs": [],
"source": [
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"num_classes = 2\n",
"bert_classifier = modeling.models.BertClassifier(\n",
" network, num_classes=num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8tZKueKYP4bB"
},
"source": [
"Inspecting the `bert_classifier`, we see it wraps the `encoder` with additional `Classification` head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "snlutm9ZJgEZ"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "yyHPHsqBJkCz"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
"print(logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "w--a2mg4nzKm"
},
"source": [
"### Compute loss\n",
"\n",
"With `logits`, we can compute `loss`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9X0S1DoFn_5Q"
},
"outputs": [],
"source": [
"labels = np.random.randint(num_classes, size=(batch_size))\n",
"\n",
"loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n",
"print(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mzBqOylZo3og"
},
"source": [
"With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Introduction to the TensorFlow Models NLP library",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
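Both notebooks defer the save-and-restore step to the full scripts (`run_pretraining.py`, `run_squad.py`). As a rough, self-contained sketch of that hand-off, the snippet below saves an encoder's weights with `tf.train.Checkpoint` and restores them into a fresh encoder wrapped by `BertSpanLabeler`; the hyperparameters and the `/tmp` path are illustrative, and this is not the exact mechanism used by those scripts.

```python
# Hedged sketch: persist a (pretrained) encoder and reuse its weights in a task model.
import tensorflow as tf
from official.nlp import modeling

vocab_size, sequence_length = 100, 16

# Stand-in for the encoder that was trained inside BertPretrainer above.
pretrained_encoder = modeling.networks.TransformerEncoder(
    vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
ckpt_path = tf.train.Checkpoint(encoder=pretrained_encoder).save(
    "/tmp/pretrained_encoder/ckpt")

# Fresh encoder for the downstream task; load the pretrained weights into it,
# then wrap it exactly as in the span labeling section of the notebook.
task_encoder = modeling.networks.TransformerEncoder(
    vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
tf.train.Checkpoint(encoder=task_encoder).restore(
    ckpt_path).assert_existing_objects_matched()
bert_span_labeler = modeling.models.BertSpanLabeler(task_encoder)
```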
...@@ -37,17 +37,29 @@ class Task(tf.Module): ...@@ -37,17 +37,29 @@ class Task(tf.Module):
# Special keys in train/validate step returned logs. # Special keys in train/validate step returned logs.
loss = "loss" loss = "loss"
def __init__(self, params: cfg.TaskConfig): def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
"""Task initialization.
Args:
params: cfg.TaskConfig instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional artifacts to this directory.
"""
self._task_config = params self._task_config = params
self._logging_dir = logging_dir
@property @property
def task_config(self) -> cfg.TaskConfig: def task_config(self) -> cfg.TaskConfig:
return self._task_config return self._task_config
@property
def logging_dir(self) -> str:
return self._logging_dir
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn. """A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint found for the model. This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir. checkpoint, saved under a directory other than the model_dir.
...@@ -59,7 +71,7 @@ class Task(tf.Module): ...@@ -59,7 +71,7 @@ class Task(tf.Module):
@abc.abstractmethod @abc.abstractmethod
def build_model(self) -> tf.keras.Model: def build_model(self) -> tf.keras.Model:
"""Creates the model architecture. """Creates model architecture.
Returns: Returns:
A model instance. A model instance.
...@@ -107,6 +119,7 @@ class Task(tf.Module): ...@@ -107,6 +119,7 @@ class Task(tf.Module):
"""Returns a dataset or a nested structure of dataset functions. """Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size. Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args: Args:
params: hyperparams to create input pipelines. params: hyperparams to create input pipelines.
...@@ -122,7 +135,7 @@ class Task(tf.Module): ...@@ -122,7 +135,7 @@ class Task(tf.Module):
Args: Args:
labels: optional label tensors. labels: optional label tensors.
model_outputs: a nested structure of output tensors. model_outputs: a nested structure of output tensors.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model. aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns: Returns:
The total loss tensor. The total loss tensor.
...@@ -172,6 +185,8 @@ class Task(tf.Module): ...@@ -172,6 +185,8 @@ class Task(tf.Module):
metrics=None): metrics=None):
"""Does forward and backward. """Does forward and backward.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
model: the model, forward pass definition. model: the model, forward pass definition.
...@@ -217,7 +232,9 @@ class Task(tf.Module): ...@@ -217,7 +232,9 @@ class Task(tf.Module):
return logs return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validation step.
With distribution strategies, this method runs on devices.
Args: Args:
inputs: a dictionary of input tensors. inputs: a dictionary of input tensors.
...@@ -244,7 +261,17 @@ class Task(tf.Module): ...@@ -244,7 +261,17 @@ class Task(tf.Module):
return logs return logs
def inference_step(self, inputs, model: tf.keras.Model): def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
Returns:
Model outputs.
"""
return model(inputs, training=False) return model(inputs, training=False)
def aggregate_logs(self, state, step_logs): def aggregate_logs(self, state, step_logs):
......
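The hunk above only shows fragments of the `Task` base class, but it documents the pieces a concrete task must supply. Below is a rough sketch of a toy subclass, assuming the class lives at `official.core.base_task`; the model, dataset, and loss are made up for illustration and are not part of this diff.

```python
# Hedged sketch of a user-defined task; module path and toy logic are assumptions.
import tensorflow as tf
from official.core import base_task  # assumed location of the Task base class


class ToyClassificationTask(base_task.Task):
  """Random 8-dim features, binary labels, a one-layer classifier."""

  def build_model(self) -> tf.keras.Model:
    inputs = tf.keras.Input(shape=(8,), dtype=tf.float32)
    logits = tf.keras.layers.Dense(2)(inputs)
    return tf.keras.Model(inputs, logits)

  def build_inputs(self, params):
    # Per-host dataset with the per-replica batch size, repeated for training.
    def generate(_):
      features = tf.random.uniform(shape=[8])
      label = tf.random.uniform(shape=[], maxval=2, dtype=tf.int32)
      return features, label
    return tf.data.Dataset.range(1).repeat().map(generate).batch(32)

  def build_losses(self, labels, model_outputs, aux_losses=None):
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True))
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return loss
```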
...@@ -171,6 +171,9 @@ class InputReader: ...@@ -171,6 +171,9 @@ class InputReader:
as_supervised=self._tfds_as_supervised, as_supervised=self._tfds_as_supervised,
decoders=decoders, decoders=decoders,
read_config=read_config) read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset return dataset
@property @property
......
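The two lines added above make TFDS-backed training input repeat indefinitely, so the training loop is bounded by `train_steps` rather than by the dataset size. A tiny stand-alone illustration of the same pattern with plain `tf.data` (not the `InputReader` API itself):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(3)
is_training = True
if is_training:
  # Same idea as the InputReader change: training data loops forever.
  dataset = dataset.repeat()

# take(7) only bounds this demo; a real run is bounded by train_steps.
print(list(dataset.take(7).as_numpy_iterator()))  # [0, 1, 2, 0, 1, 2, 0]
```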
...@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict): ...@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict):
subconfig_type = Config subconfig_type = Config
if k in cls.__annotations__: if k in cls.__annotations__:
# Directly Config subtype. # Directly Config subtype.
type_annotation = cls.__annotations__[k] type_annotation = cls.__annotations__[k] # pytype: disable=invalid-annotation
if (isinstance(type_annotation, type) and if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)): issubclass(type_annotation, Config)):
subconfig_type = cls.__annotations__[k] subconfig_type = cls.__annotations__[k] # pytype: disable=invalid-annotation
else: else:
# Check if the field is a sequence of subtypes. # Check if the field is a sequence of subtypes.
field_type = getattr(type_annotation, '__origin__', type(None)) field_type = getattr(type_annotation, '__origin__', type(None))
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common configuration settings.""" """Common configuration settings."""
from typing import Optional, Union from typing import Optional, Union
import dataclasses import dataclasses
...@@ -111,6 +112,8 @@ class RuntimeConfig(base_config.Config): ...@@ -111,6 +112,8 @@ class RuntimeConfig(base_config.Config):
run_eagerly: Whether or not to run the experiment eagerly. run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance. persistent mode for CuDNN batch norm kernel for improved GPU performance.
allow_tpu_summary: Whether to allow summaries to be written inside the XLA program
that runs on TPU, through automatic outside compilation.
""" """
distribution_strategy: str = "mirrored" distribution_strategy: str = "mirrored"
enable_xla: bool = False enable_xla: bool = False
...@@ -123,8 +126,8 @@ class RuntimeConfig(base_config.Config): ...@@ -123,8 +126,8 @@ class RuntimeConfig(base_config.Config):
task_index: int = -1 task_index: int = -1
all_reduce_alg: Optional[str] = None all_reduce_alg: Optional[str] = None
num_packs: int = 1 num_packs: int = 1
loss_scale: Optional[Union[str, float]] = None
mixed_precision_dtype: Optional[str] = None mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False batchnorm_spatial_persistent: bool = False
...@@ -172,23 +175,32 @@ class TrainerConfig(base_config.Config): ...@@ -172,23 +175,32 @@ class TrainerConfig(base_config.Config):
eval_tf_function: whether or not to use tf_function for eval. eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop. steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary. summary_interval: number of steps between each summary.
checkpoint_intervals: number of steps between checkpoints. checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep. max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinetely. checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
""" """
optimizer_config: OptimizationConfig = OptimizationConfig() optimizer_config: OptimizationConfig = OptimizationConfig()
train_steps: int = 0 # Orbit settings.
validation_steps: Optional[int] = None train_tf_while_loop: bool = True
validation_interval: int = 100 train_tf_function: bool = True
eval_tf_function: bool = True
allow_tpu_summary: bool = False
# Trainer intervals.
steps_per_loop: int = 1000 steps_per_loop: int = 1000
summary_interval: int = 1000 summary_interval: int = 1000
checkpoint_interval: int = 1000 checkpoint_interval: int = 1000
# Checkpoint manager.
max_to_keep: int = 5 max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None continuous_eval_timeout: Optional[int] = None
train_tf_while_loop: bool = True # Train/Eval routines.
train_tf_function: bool = True train_steps: int = 0
eval_tf_function: bool = True validation_steps: Optional[int] = None
validation_interval: int = 1000
@dataclasses.dataclass @dataclasses.dataclass
......
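For orientation, the regrouped `TrainerConfig` fields above correspond to the `trainer` section of an experiment's params. The dict below is only a sketch of that shape, using the defaults visible in this diff plus an illustrative `train_steps`; it is not an officially recommended configuration.

```python
# Illustrative trainer params matching the reorganized TrainerConfig fields.
trainer_params = {
    # Orbit settings.
    "train_tf_while_loop": True,
    "train_tf_function": True,
    "eval_tf_function": True,
    "allow_tpu_summary": False,
    # Trainer intervals.
    "steps_per_loop": 1000,
    "summary_interval": 1000,
    "checkpoint_interval": 1000,
    # Checkpoint manager.
    "max_to_keep": 5,
    "continuous_eval_timeout": None,  # None: wait indefinitely between checkpoints.
    # Train/eval routines.
    "train_steps": 10000,             # illustrative; the default is 0.
    "validation_steps": None,         # None: use the entire eval dataset.
    "validation_interval": 1000,
}
```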
...@@ -20,6 +20,20 @@ import dataclasses ...@@ -20,6 +20,20 @@ import dataclasses
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ConstantLrConfig(base_config.Config):
"""Configuration for constant learning rate.
This class is a container for the constant learning rate decay configs.
Attributes:
name: The name of the learning rate schedule. Defaults to Constant.
learning_rate: A float. The learning rate. Defaults to 0.1.
"""
name: str = 'Constant'
learning_rate: float = 0.1
@dataclasses.dataclass @dataclasses.dataclass
class StepwiseLrConfig(base_config.Config): class StepwiseLrConfig(base_config.Config):
"""Configuration for stepwise learning rate decay. """Configuration for stepwise learning rate decay.
......
...@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig): ...@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig):
Attributes: Attributes:
type: 'str', type of lr schedule to be used, one of the fields below. type: 'str', type of lr schedule to be used, one of the fields below.
constant: constant learning rate config.
stepwise: stepwise learning rate config. stepwise: stepwise learning rate config.
exponential: exponential learning rate config. exponential: exponential learning rate config.
polynomial: polynomial learning rate config. polynomial: polynomial learning rate config.
cosine: cosine learning rate config. cosine: cosine learning rate config.
""" """
type: Optional[str] = None type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig() stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig() exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig() polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
......
...@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config): ...@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for SGD optimizer.
decay: decay rate for SGD optimizer. decay: decay rate for SGD optimizer.
nesterov: nesterov for SGD optimizer. nesterov: nesterov for SGD optimizer.
momentum: momentum for SGD optimizer. momentum: momentum for SGD optimizer.
""" """
name: str = "SGD" name: str = "SGD"
learning_rate: float = 0.01
decay: float = 0.0 decay: float = 0.0
nesterov: bool = False nesterov: bool = False
momentum: float = 0.0 momentum: float = 0.0
...@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config): ...@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for RMSprop optimizer.
rho: discounting factor for RMSprop optimizer. rho: discounting factor for RMSprop optimizer.
momentum: momentum for RMSprop optimizer. momentum: momentum for RMSprop optimizer.
epsilon: epsilon value for RMSprop optimizer, helps with numerical stability. epsilon: epsilon value for RMSprop optimizer, helps with numerical stability.
centered: Whether to normalize gradients or not. centered: Whether to normalize gradients or not.
""" """
name: str = "RMSprop" name: str = "RMSprop"
learning_rate: float = 0.001
rho: float = 0.9 rho: float = 0.9
momentum: float = 0.0 momentum: float = 0.0
epsilon: float = 1e-7 epsilon: float = 1e-7
...@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config): ...@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments. beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in Adam optimizer. epsilon: epsilon value used for numerical stability in Adam optimizer.
...@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config): ...@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config):
the paper "On the Convergence of Adam and beyond". the paper "On the Convergence of Adam and beyond".
""" """
name: str = "Adam" name: str = "Adam"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-07 epsilon: float = 1e-07
...@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config): ...@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for the optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments. beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in the optimizer. epsilon: epsilon value used for numerical stability in the optimizer.
...@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config): ...@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config):
include in weight decay. include in weight decay.
""" """
name: str = "AdamWeightDecay" name: str = "AdamWeightDecay"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-07 epsilon: float = 1e-07
...@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config): ...@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments. beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in LAMB optimizer. epsilon: epsilon value used for numerical stability in LAMB optimizer.
...@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config): ...@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config):
be excluded. be excluded.
""" """
name: str = "LAMB" name: str = "LAMB"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-6 epsilon: float = 1e-6
......
...@@ -60,7 +60,7 @@ class OptimizerFactory(object): ...@@ -60,7 +60,7 @@ class OptimizerFactory(object):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -88,12 +88,15 @@ class OptimizerFactory(object): ...@@ -88,12 +88,15 @@ class OptimizerFactory(object):
self._optimizer_config = config.optimizer.get() self._optimizer_config = config.optimizer.get()
self._optimizer_type = config.optimizer.type self._optimizer_type = config.optimizer.type
if self._optimizer_config is None: if self._optimizer_type is None:
raise ValueError('Optimizer type must be specified') raise ValueError('Optimizer type must be specified')
self._lr_config = config.learning_rate.get() self._lr_config = config.learning_rate.get()
self._lr_type = config.learning_rate.type self._lr_type = config.learning_rate.type
if self._lr_type is None:
raise ValueError('Learning rate type must be specified')
self._warmup_config = config.warmup.get() self._warmup_config = config.warmup.get()
self._warmup_type = config.warmup.type self._warmup_type = config.warmup.type
...@@ -101,18 +104,15 @@ class OptimizerFactory(object): ...@@ -101,18 +104,15 @@ class OptimizerFactory(object):
"""Build learning rate. """Build learning rate.
Builds learning rate from config. Learning rate schedule is built according Builds learning rate from config. Learning rate schedule is built according
to the learning rate config. If there is no learning rate config, optimizer to the learning rate config. If learning rate type is constant,
learning rate is returned. lr_config.learning_rate is returned.
Returns: Returns:
tf.keras.optimizers.schedules.LearningRateSchedule instance. If no tf.keras.optimizers.schedules.LearningRateSchedule instance. If
learning rate schedule defined, optimizer_config.learning_rate is learning rate type is constant, lr_config.learning_rate is returned.
returned.
""" """
if self._lr_type == 'constant':
# TODO(arashwan): Explore if we want to only allow explicit const lr sched. lr = self._lr_config.learning_rate
if not self._lr_config:
lr = self._optimizer_config.learning_rate
else: else:
lr = LR_CLS[self._lr_type](**self._lr_config.as_dict()) lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
......
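Taken together with the config changes above, the factory now requires the learning rate to come from the `learning_rate` config (with an explicit `constant` type) instead of from the optimizer config. A hedged usage sketch follows, mirroring the params structure in the docstring and tests; the exact import paths are assumed from the repository layout.

```python
# Sketch of building an optimizer under the new config layout (import paths assumed).
from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

params = {
    "optimizer": {
        "type": "sgd",
        "sgd": {"momentum": 0.9},            # no learning_rate here anymore
    },
    "learning_rate": {
        "type": "constant",
        "constant": {"learning_rate": 0.1},  # explicit constant LR config
    },
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()   # 0.1 for the constant schedule
optimizer = opt_factory.build_optimizer(lr)
```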
...@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': optimizer_type 'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
} }
} }
optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type] optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config() expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
opt_config = optimization_config.OptimizationConfig(params) opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config) opt_factory = optimizer_factory.OptimizerFactory(opt_config)
...@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(optimizer, optimizer_cls) self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config()) self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_missing_types(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
params = {
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
def test_stepwise_lr_schedule(self): def test_stepwise_lr_schedule(self):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'exponential', 'type': 'exponential',
...@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'polynomial', 'type': 'polynomial',
...@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'cosine', 'type': 'cosine',
...@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}, },
'warmup': { 'warmup': {
'type': 'linear', 'type': 'linear',
...@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
......