Commit 31ca3b97 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

resovle merge conflicts

parents 3e9d886d 7fcd7cba
......@@ -10,24 +10,21 @@ can take full advantage of TensorFlow for their research and product development
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |
## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
| Date | News |
|------|------|
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released
| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1
| July 10, 2020 | TensorFlow 2 meets the [Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) ([Blog](https://blog.tensorflow.org/2020/07/tensorflow-2-meets-object-detection-api.html)) |
| June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) |
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
| May 19, 2020 | [MobileDets: Searching for Object Detection Architectures for Mobile Accelerators](https://github.com/tensorflow/models/tree/master/research/object_detection#may-19th-2020) released |
| May 7, 2020 | [MnasFPN with MobileNet-V2 backbone](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#mobile-models) released for object detection |
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
## [Milestones](https://github.com/tensorflow/models/milestones)
| Date | Milestone |
|------|-----------|
| July 8, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |
## Contributions
[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
......
......@@ -20,6 +20,14 @@ This repository provides a curated list of the GitHub repositories with machine
| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Object Detection
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Segmentation
| Model | Paper | Features | Maintainer |
......
......@@ -17,11 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
The team is actively developing new models.
In the near future, we will add:
* State-of-the-art language understanding models:
More members in Transformer family
* Start-of-the-art image classification models:
EfficientNet, MnasNet, and variants
* A set of excellent objection detection models.
* State-of-the-art language understanding models.
* State-of-the-art image classification models.
* State-of-the-art objection detection and instance segmentation models.
## Table of Contents
......@@ -52,6 +50,7 @@ In the near future, we will add:
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing
......
......@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark(summary_path=summary_path,
report_accuracy=True)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 2x2 TPU for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 2x2 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 4x4 TPU for 10000 steps."""
......@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 4x4 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 8x8 TPU for 10000 steps."""
......
......@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
class CtlBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing."""
def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
**kwargs):
self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {}
super(CtlBenchmark, self).__init__(
output_dir=output_dir,
default_flags=self.default_flags,
flag_methods=self.flag_methods)
flag_methods=self.flag_methods,
**kwargs)
def _report_benchmark(self,
stats,
......@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None):
def __init__(self, output_dir=None, default_flags=None, **kwargs):
flag_methods = [common.define_keras_flags]
super(Resnet50CtlBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags)
default_flags=default_flags,
**kwargs)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self):
......@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.single_l2_loss_op = True
FLAGS.use_tf_function = True
FLAGS.enable_checkpoint_and_export = False
FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'
def benchmark_2x2_tpu_bf16(self):
self._setup()
self._set_df_common()
FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_bf16_mlir(self):
self._setup()
self._set_df_common()
FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16'
tf.config.experimental.enable_mlir_bridge()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
self._run_and_report_benchmark()
def benchmark_4x4_tpu_bf16(self):
......@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common()
FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
......@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common()
FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark()
......@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20
def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags)
output_dir=output_dir, default_flags=def_flags, **kwargs)
class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
......@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20
def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags)
output_dir=output_dir, default_flags=def_flags, **kwargs)
if __name__ == '__main__':
......
......@@ -271,6 +271,61 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self):
"""Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['architecture']['backbone'] = 'spinenet'
params['architecture']['multilevel_features'] = 'identity'
params['architecture']['use_bfloat16'] = False
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
params['train']['checkpoint']['path'] = ''
FLAGS.model_dir = self._get_model_dir(
'real_benchmark_2x2_tpu_spinenet_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
if __name__ == '__main__':
tf.test.main()
......@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core
TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
GPU_DATA_DIR = os.getenv('TMPDIR')
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS
......@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
Code under test for the Transformer Keras models report the same data and
require the same FLAG setup.
"""
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None, tpu=None):
self._set_data_files(root_data_dir=root_data_dir)
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_files(self, root_data_dir=None, tpu_run=False):
"""Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
# Use remote storage for TPU, remote storage for GPU if defined, else
# use environment provided root_data_dir.
if tpu_run:
root_data_dir = TPU_DATA_DIR
elif GPU_DATA_DIR is not None:
root_data_dir = GPU_DATA_DIR
root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME)
self.vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768')
self.bleu_source = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.en')
self.bleu_ref = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de')
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_file_flags(self):
"""Sets the FLAGS for the data files."""
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
......@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 2048
FLAGS.train_steps = 1000
......@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096
FLAGS.train_steps = 100000
......@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000
......@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000
......@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Iterations are not epochs, an iteration is a number of steps between evals.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
......@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
......@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
not epochs, an iteration is a number of steps between evals.
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
......@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12
......@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
......@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet."
"""
self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8
FLAGS.static_batch = True
......@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu)
def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
def _set_df_common(self):
self._set_data_files(tpu_run=True)
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
FLAGS.distribution_strategy = 'tpu'
FLAGS.padded_decode = True
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 6144
FLAGS.enable_checkpointing = False
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.batch_size = 6144
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
"""Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
FLAGS.batch_size = 6144
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
......@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu(self):
"""Port of former GCP transformer_big model on 4x4."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
......@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu_mlir(self):
"""Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
......
......@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
"""Runs and reports the benchmark given the provided configuration."""
params = unet_training_lib.extract_params(FLAGS)
strategy = unet_training_lib.create_distribution_strategy(params)
if params.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
stats = {}
......
......@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
......@@ -104,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -128,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -185,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -204,12 +204,12 @@
"id": "9uFskufsR2LT"
},
"source": [
"You can get a pre-trained BERT encoder from TensorFlow Hub here:"
"You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -252,7 +252,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -267,7 +267,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -290,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -313,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -336,7 +336,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -376,7 +376,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -404,7 +404,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -446,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -469,7 +469,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -490,7 +490,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -514,7 +514,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -562,7 +562,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -587,7 +587,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -617,7 +617,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -661,7 +661,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -691,7 +691,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -737,7 +737,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -769,7 +769,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -793,7 +793,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -816,7 +816,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -845,7 +845,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -870,7 +870,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -908,7 +908,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -943,7 +943,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -986,7 +986,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1023,7 +1023,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1055,7 +1055,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1071,7 +1071,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1096,7 +1096,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1110,7 +1110,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1176,7 +1176,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1201,7 +1201,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1240,7 +1240,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1273,7 +1273,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1306,7 +1306,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1351,7 +1351,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1379,7 +1379,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1406,17 +1406,44 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
"id": "GDWrHm0BGpbX"
},
"outputs": [],
"source": [
"# Note: 350MB download.\n",
"import tensorflow_hub as hub\n",
"hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n",
"import tensorflow_hub as hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Y29meH0qGq_5"
},
"outputs": [],
"source": [
"hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
" trainable=True)\n",
"\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
]
......@@ -1433,7 +1460,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1466,7 +1493,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1491,7 +1518,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1504,7 +1531,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1545,7 +1572,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1569,7 +1596,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1592,7 +1619,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1617,7 +1644,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1643,7 +1670,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1661,7 +1688,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1688,7 +1715,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1714,7 +1741,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1733,7 +1760,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1761,7 +1788,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......@@ -1795,7 +1822,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Bp8t2AI8i7uP"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "rxPj2Lsni9O4"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6xS-9i5DrRvO"
},
"source": [
"# Customizing a Transformer Encoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Mwb9uw1cDXsa"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iLrcV4IyrcGX"
},
"source": [
"## Learning objectives\n",
"\n",
"The [TensorFlow Models NLP library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling) is a collection of tools for building and training modern high performance natural language models.\n",
"\n",
"The [TransformEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) is the core of this library, and lots of new network architectures are proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YYxdyoWgsl8t"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fEJSFutUsn_h"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "thsKZDjhswhR"
},
"outputs": [],
"source": [
"!pip install -q tf-nightly\n",
"!pip install -q tf-models-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hpf7JPCVsqtv"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "my4dp-RMssQe"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from official.modeling import activations\n",
"from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vjDmVsFfs85n"
},
"source": [
"## Canonical BERT encoder\n",
"\n",
"Before learning how to customize the encoder, let's firstly create a canonical BERT enoder and use it to instantiate a `BertClassifier` for classification task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Oav8sbgstWc-"
},
"outputs": [],
"source": [
"cfg = {\n",
" \"vocab_size\": 100,\n",
" \"hidden_size\": 32,\n",
" \"num_layers\": 3,\n",
" \"num_attention_heads\": 4,\n",
" \"intermediate_size\": 64,\n",
" \"activation\": activations.gelu,\n",
" \"dropout_rate\": 0.1,\n",
" \"attention_dropout_rate\": 0.1,\n",
" \"sequence_length\": 16,\n",
" \"type_vocab_size\": 2,\n",
" \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
"}\n",
"bert_encoder = modeling.networks.TransformerEncoder(**cfg)\n",
"\n",
"def build_classifier(bert_encoder):\n",
" return modeling.models.BertClassifier(bert_encoder, num_classes=2)\n",
"\n",
"canonical_classifier_model = build_classifier(bert_encoder)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qe2UWI6_tsHo"
},
"source": [
"`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb). We skip the code that trains the model here.\n",
"\n",
"After training, we can apply the model to do prediction.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "csED2d-Yt5h6"
},
"outputs": [],
"source": [
"def predict(model):\n",
" batch_size = 3\n",
" np.random.seed(0)\n",
" word_ids = np.random.randint(\n",
" cfg[\"vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
" mask = np.random.randint(2, size=(batch_size, cfg[\"sequence_length\"]))\n",
" type_ids = np.random.randint(\n",
" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
" print(model([word_ids, mask, type_ids], training=False))\n",
"\n",
"predict(canonical_classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PzKStEK9t_Pb"
},
"source": [
"## Customize BERT encoder\n",
"\n",
"One BERT encoder consists of an embedding network and multiple transformer blocks, and each transformer block contains an attention layer and a feedforward layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rmwQfhj6fmKz"
},
"source": [
"We provide easy ways to customize each of those components via (1)\n",
"[EncoderScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) and (2) [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xsMgEVHAui11"
},
"source": [
"### Use EncoderScaffold\n",
"\n",
"`EncoderScaffold` allows users to provide a custom embedding subnetwork\n",
" (which will replace the standard embedding logic) and/or a custom hidden layer class (which will replace the `Transformer` instantiation in the encoder)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-JBabpa2AOz8"
},
"source": [
"#### Without Customization\n",
"\n",
"Without any customization, `EncoderScaffold` behaves the same the canonical `TransformerEncoder`.\n",
"\n",
"As shown in the following example, `EncoderScaffold` can load `TransformerEncoder`'s weights and output the same values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ktNzKuVByZQf"
},
"outputs": [],
"source": [
"default_hidden_cfg = dict(\n",
" num_attention_heads=cfg[\"num_attention_heads\"],\n",
" intermediate_size=cfg[\"intermediate_size\"],\n",
" intermediate_activation=activations.gelu,\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" attention_dropout_rate=cfg[\"attention_dropout_rate\"],\n",
" kernel_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
")\n",
"default_embedding_cfg = dict(\n",
" vocab_size=cfg[\"vocab_size\"],\n",
" type_vocab_size=cfg[\"type_vocab_size\"],\n",
" hidden_size=cfg[\"hidden_size\"],\n",
" seq_length=cfg[\"sequence_length\"],\n",
" initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" max_seq_length=cfg[\"sequence_length\"],\n",
")\n",
"default_kwargs = dict(\n",
" hidden_cfg=default_hidden_cfg,\n",
" embedding_cfg=default_embedding_cfg,\n",
" num_hidden_instances=cfg[\"num_layers\"],\n",
" pooled_output_dim=cfg[\"hidden_size\"],\n",
" return_all_layer_outputs=True,\n",
" pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
")\n",
"encoder_scaffold = modeling.networks.EncoderScaffold(**default_kwargs)\n",
"classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
"classifier_model_from_encoder_scaffold.set_weights(\n",
" canonical_classifier_model.get_weights())\n",
"predict(classifier_model_from_encoder_scaffold)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sMaUmLyIuwcs"
},
"source": [
"#### Customize Embedding\n",
"\n",
"Next, we show how to use a customized embedding network.\n",
"\n",
"We firstly build an embedding network that will replace the default network. This one will have 2 inputs (`mask` and `word_ids`) instead of 3, and won't use positional embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LTinnaG6vcsw"
},
"outputs": [],
"source": [
"word_ids = tf.keras.layers.Input(\n",
" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
"mask = tf.keras.layers.Input(\n",
" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
"embedding_layer = modeling.layers.OnDeviceEmbedding(\n",
" vocab_size=cfg['vocab_size'],\n",
" embedding_width=cfg['hidden_size'],\n",
" initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
" name=\"word_embeddings\")\n",
"word_embeddings = embedding_layer(word_ids)\n",
"attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])\n",
"new_embedding_network = tf.keras.Model([word_ids, mask],\n",
" [word_embeddings, attention_mask])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HN7_yu-6O3qI"
},
"source": [
"Inspecting `new_embedding_network`, we can see it takes two inputs:\n",
"`input_word_ids` and `input_mask`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fO9zKFE4OpHp"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9cOaGQHLv12W"
},
"source": [
"We then can build a new encoder using the above `new_embedding_network`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "mtFDMNf2vIl9"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use new embedding network.\n",
"kwargs['embedding_cls'] = new_embedding_network\n",
"kwargs['embedding_data'] = embedding_layer.embeddings\n",
"\n",
"encoder_with_customized_embedding = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_customized_embedding)\n",
"# ... Train the model ...\n",
"print(classifier_model.inputs)\n",
"\n",
"# Assert that there are only two inputs.\n",
"assert len(classifier_model.inputs) == 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Z73ZQDtmwg9K"
},
"source": [
"#### Customized Transformer\n",
"\n",
"User can also override the [hidden_cls](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py#L103) argument in `EncoderScaffold`'s constructor to employ a customized Transformer layer.\n",
"\n",
"See [ReZeroTransformer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/rezero_transformer.py) for how to implement a customized Transformer layer.\n",
"\n",
"Following is an example of using `ReZeroTransformer`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uAIarLZgw6pA"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use ReZeroTransformer.\n",
"kwargs['hidden_cls'] = modeling.layers.ReZeroTransformer\n",
"\n",
"encoder_with_rezero_transformer = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_rezero_transformer)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
"assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6PMHFdvnxvR0"
},
"source": [
"### Use [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)\n",
"\n",
"The above method of customizing `Transformer` requires rewriting the whole `Transformer` layer, while sometimes you may only want to customize either attention layer or feedforward block. In this case, [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py) can be used.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "D6FejlgwyAy_"
},
"source": [
"#### Customize Attention Layer\n",
"\n",
"User can also override the [attention_cls](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py#L45) argument in `TransformerScaffold`'s constructor to employ a customized Attention layer.\n",
"\n",
"See [TalkingHeadsAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py) for how to implement a customized `Attention` layer.\n",
"\n",
"Following is an example of using [TalkingHeadsAttention](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "nFrSMrZuyNeQ"
},
"outputs": [],
"source": [
"# Use TalkingHeadsAttention\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['attention_cls'] = modeling.layers.TalkingHeadsAttention\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = modeling.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
"assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kuEJcTyByVvI"
},
"source": [
"#### Customize Feedforward Layer\n",
"\n",
"Similiarly, one could also customize the feedforward layer.\n",
"\n",
"See [GatedFeedforward](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py) for how to implement a customized feedforward layer.\n",
"\n",
"Following is an example of using [GatedFeedforward](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XAbKy_l4y_-i"
},
"outputs": [],
"source": [
"# Use TalkingHeadsAttention\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['feedforward_cls'] = modeling.layers.GatedFeedforward\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = modeling.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder_with_gated_feedforward = modeling.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_gated_feedforward)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `gate` from GatedFeedforward exists.\n",
"assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "a_8NWUhkzeAq"
},
"source": [
"### Build a new Encoder using building blocks from KerasBERT.\n",
"\n",
"Finally, you could also build a new encoder using building blocks in the modeling library.\n",
"\n",
"See [AlbertTransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_transformer_encoder.py) as an example:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xsiA3RzUzmUM"
},
"outputs": [],
"source": [
"albert_encoder = modeling.networks.AlbertTransformerEncoder(**cfg)\n",
"classifier_model = build_classifier(albert_encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MeidDfhlHKSO"
},
"source": [
"Inspecting the `albert_encoder`, we see it stacks the same `Transformer` layer multiple times."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Uv_juT22HERW"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Customizing a Transformer Encoder",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "80xnUmoI7fBX"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "8nvTnfs6Q692"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WmfcMK5P5C1G"
},
"source": [
"# Introduction to the TensorFlow Models NLP library"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cH-oJ8R6AHMK"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0H_EFIhq4-MJ"
},
"source": [
"## Learning objectives\n",
"\n",
"In this Colab notebook, you will learn how to build transformer-based models for common NLP tasks including pretraining, span labelling and classification using the building blocks from [NLP modeling library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2N97-dps_nUk"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "459ygAVl_rg0"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-nightly` is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Y-qGkdh6_sZc"
},
"outputs": [],
"source": [
"!pip install -q tf-nightly\n",
"!pip install -q tf-models-nightly"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "e4huSSwyAG_5"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jqYXqtjBAJd9"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "djBQWjvy-60Y"
},
"source": [
"## BERT pretraining model\n",
"\n",
"BERT ([Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)) introduced the method of pre-training language representations on a large text corpus and then using that model for downstream NLP tasks.\n",
"\n",
"In this section, we will learn how to build a model to pretrain BERT on the masked language modeling task and next sentence prediction task. For simplicity, we only show the minimum example and use dummy data."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MKuHVlsCHmiq"
},
"source": [
"### Build a `BertPretrainer` model wrapping `TransformerEncoder`\n",
"\n",
"The [TransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/transformer_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n",
"\n",
"The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "EXkcXz-9BwB3"
},
"outputs": [],
"source": [
"# Build a small transformer network.\n",
"vocab_size = 100\n",
"sequence_length = 16\n",
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=16)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0NH5irV5KTMS"
},
"source": [
"Inspecting the encoder, we see it contains few embedding layers, stacked `Transformer` layers and are connected to three input layers:\n",
"\n",
"`input_word_ids`, `input_type_ids` and `input_mask`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lZNoZkBrIoff"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "o7eFOZXiIl-b"
},
"outputs": [],
"source": [
"# Create a BERT pretrainer with the created network.\n",
"num_token_predictions = 8\n",
"bert_pretrainer = modeling.models.BertPretrainer(\n",
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "d5h5HT7gNHx_"
},
"source": [
"Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `Classification` heads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2tcNfm03IBF7"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "F2oHrXGUIS0M"
},
"outputs": [],
"source": [
"# We can feed some dummy data to get masked language model and sentence output.\n",
"batch_size = 2\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"masked_lm_positions_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"\n",
"outputs = bert_pretrainer(\n",
" [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
"lm_output = outputs[\"masked_lm\"]\n",
"sentence_output = outputs[\"classification\"]\n",
"print(lm_output)\n",
"print(sentence_output)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "bnx3UCHniCS5"
},
"source": [
"### Compute loss\n",
"Next, we can use `lm_output` and `sentence_output` to compute `loss`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "k30H4Q86f52x"
},
"outputs": [],
"source": [
"masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n",
"\n",
"mlm_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=masked_lm_ids_data,\n",
" predictions=lm_output,\n",
" weights=masked_lm_weights_data)\n",
"sentence_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=next_sentence_labels_data,\n",
" predictions=sentence_output)\n",
"loss = mlm_loss + sentence_loss\n",
"print(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wrmSs8GjHxVw"
},
"source": [
"With the loss, you can optimize the model.\n",
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_pretraining.py) for the full example.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "k8cQVFvBCV4s"
},
"source": [
"## Span labeling model\n",
"\n",
"Span labeling is the task to assign labels to a span of the text, for example, label a span of text as the answer of a given question.\n",
"\n",
"In this section, we will learn how to build a span labeling model. Again, we use dummy data for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xrLLEWpfknUW"
},
"source": [
"### Build a BertSpanLabeler wrapping TransformerEncoder\n",
"\n",
"[BertSpanLabeler](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_span_labeler.py) implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
"\n",
"Note that `BertSpanLabeler` wraps a `TransformerEncoder`, the weights of which can be restored from the above pretraining model.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "B941M4iUCejO"
},
"outputs": [],
"source": [
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"bert_span_labeler = modeling.models.BertSpanLabeler(network)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QpB9pgj4PpMg"
},
"source": [
"Inspecting the `bert_span_labeler`, we see it wraps the encoder with additional `SpanLabeling` that outputs `start_position` and `end_postion`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RbqRNJCLJu4H"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fUf1vRxZJwio"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
"print(start_logits)\n",
"print(end_logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WqhgQaN1lt-G"
},
"source": [
"### Compute loss\n",
"With `start_logits` and `end_logits`, we can compute loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "waqs6azNl3Nn"
},
"outputs": [],
"source": [
"start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"\n",
"start_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" start_positions, start_logits, from_logits=True)\n",
"end_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" end_positions, end_logits, from_logits=True)\n",
"\n",
"total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
"print(total_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Zdf03YtZmd_d"
},
"source": [
"With the `loss`, you can optimize the model. Please see [run_squad.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_squad.py) for the full example."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0A1XnGSTChg9"
},
"source": [
"## Classification model\n",
"\n",
"In the last section, we show how to build a text classification model.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MSK8OpZgnQa9"
},
"source": [
"### Build a BertClassifier model wrapping TransformerEncoder\n",
"\n",
"[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a simple token classification model containing a single classification head using the `TokenClassification` network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cXXCsffkCphk"
},
"outputs": [],
"source": [
"network = modeling.networks.TransformerEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"num_classes = 2\n",
"bert_classifier = modeling.models.BertClassifier(\n",
" network, num_classes=num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8tZKueKYP4bB"
},
"source": [
"Inspecting the `bert_classifier`, we see it wraps the `encoder` with additional `Classification` head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "snlutm9ZJgEZ"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "yyHPHsqBJkCz"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
"print(logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "w--a2mg4nzKm"
},
"source": [
"### Compute loss\n",
"\n",
"With `logits`, we can compute `loss`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9X0S1DoFn_5Q"
},
"outputs": [],
"source": [
"labels = np.random.randint(num_classes, size=(batch_size))\n",
"\n",
"loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n",
"print(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mzBqOylZo3og"
},
"source": [
"With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Introduction to the TensorFlow Models NLP library",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
......@@ -37,17 +37,29 @@ class Task(tf.Module):
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self, params: cfg.TaskConfig):
def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
"""Task initialization.
Args:
params: cfg.TaskConfig instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
"""
self._task_config = params
self._logging_dir = logging_dir
@property
def task_config(self) -> cfg.TaskConfig:
return self._task_config
@property
def logging_dir(self) -> str:
return self._logging_dir
def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint found for the model.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir.
......@@ -59,7 +71,7 @@ class Task(tf.Module):
@abc.abstractmethod
def build_model(self) -> tf.keras.Model:
"""Creates the model architecture.
"""Creates model architecture.
Returns:
A model instance.
......@@ -107,6 +119,7 @@ class Task(tf.Module):
"""Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args:
params: hyperparams to create input pipelines.
......@@ -122,7 +135,7 @@ class Task(tf.Module):
Args:
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
......@@ -172,6 +185,8 @@ class Task(tf.Module):
metrics=None):
"""Does forward and backward.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
......@@ -217,7 +232,9 @@ class Task(tf.Module):
return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step.
"""Validation step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
......@@ -244,7 +261,17 @@ class Task(tf.Module):
return logs
def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step."""
"""Performs the forward step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
Returns:
Model outputs.
"""
return model(inputs, training=False)
def aggregate_logs(self, state, step_logs):
......
......@@ -171,6 +171,9 @@ class InputReader:
as_supervised=self._tfds_as_supervised,
decoders=decoders,
read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset
@property
......
......@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict):
subconfig_type = Config
if k in cls.__annotations__:
# Directly Config subtype.
type_annotation = cls.__annotations__[k]
type_annotation = cls.__annotations__[k] # pytype: disable=invalid-annotation
if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)):
subconfig_type = cls.__annotations__[k]
subconfig_type = cls.__annotations__[k] # pytype: disable=invalid-annotation
else:
# Check if the field is a sequence of subtypes.
field_type = getattr(type_annotation, '__origin__', type(None))
......
......@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional, Union
import dataclasses
......@@ -111,6 +112,8 @@ class RuntimeConfig(base_config.Config):
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
allow_tpu_summary: Whether to allow summary happen inside the XLA program
runs on TPU through automatic outside compilation.
"""
distribution_strategy: str = "mirrored"
enable_xla: bool = False
......@@ -123,8 +126,8 @@ class RuntimeConfig(base_config.Config):
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[Union[str, float]] = None
mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
......@@ -172,23 +175,32 @@ class TrainerConfig(base_config.Config):
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_intervals: number of steps between checkpoints.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinetely.
checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 100
# Orbit settings.
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
allow_tpu_summary: bool = False
# Trainer intervals.
steps_per_loop: int = 1000
summary_interval: int = 1000
checkpoint_interval: int = 1000
# Checkpoint manager.
max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
# Train/Eval routines.
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 1000
@dataclasses.dataclass
......
......@@ -20,6 +20,20 @@ import dataclasses
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ConstantLrConfig(base_config.Config):
"""Configuration for constant learning rate.
This class is a containers for the constant learning rate decay configs.
Attributes:
name: The name of the learning rate schedule. Defaults to Constant.
learning_rate: A float. The learning rate. Defaults to 0.1.
"""
name: str = 'Constant'
learning_rate: float = 0.1
@dataclasses.dataclass
class StepwiseLrConfig(base_config.Config):
"""Configuration for stepwise learning rate decay.
......
......@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig):
Attributes:
type: 'str', type of lr schedule to be used, on the of fields below.
constant: constant learning rate config.
stepwise: stepwise learning rate config.
exponential: exponential learning rate config.
polynomial: polynomial learning rate config.
cosine: cosine learning rate config.
"""
type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
......
......@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for SGD optimizer.
decay: decay rate for SGD optimizer.
nesterov: nesterov for SGD optimizer.
momentum: momentum for SGD optimizer.
"""
name: str = "SGD"
learning_rate: float = 0.01
decay: float = 0.0
nesterov: bool = False
momentum: float = 0.0
......@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for RMSprop optimizer.
rho: discounting factor for RMSprop optimizer.
momentum: momentum for RMSprop optimizer.
epsilon: epsilon value for RMSprop optimizer, help with numerical stability.
centered: Whether to normalize gradients or not.
"""
name: str = "RMSprop"
learning_rate: float = 0.001
rho: float = 0.9
momentum: float = 0.0
epsilon: float = 1e-7
......@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in Adam optimizer.
......@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config):
the paper "On the Convergence of Adam and beyond".
"""
name: str = "Adam"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-07
......@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for the optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in the optimizer.
......@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config):
include in weight decay.
"""
name: str = "AdamWeightDecay"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-07
......@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in LAMB optimizer.
......@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config):
be excluded.
"""
name: str = "LAMB"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-6
......
......@@ -60,7 +60,7 @@ class OptimizerFactory(object):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......@@ -88,12 +88,15 @@ class OptimizerFactory(object):
self._optimizer_config = config.optimizer.get()
self._optimizer_type = config.optimizer.type
if self._optimizer_config is None:
if self._optimizer_type is None:
raise ValueError('Optimizer type must be specified')
self._lr_config = config.learning_rate.get()
self._lr_type = config.learning_rate.type
if self._lr_type is None:
raise ValueError('Learning rate type must be specified')
self._warmup_config = config.warmup.get()
self._warmup_type = config.warmup.type
......@@ -101,18 +104,15 @@ class OptimizerFactory(object):
"""Build learning rate.
Builds learning rate from config. Learning rate schedule is built according
to the learning rate config. If there is no learning rate config, optimizer
learning rate is returned.
to the learning rate config. If learning rate type is consant,
lr_config.learning_rate is returned.
Returns:
tf.keras.optimizers.schedules.LearningRateSchedule instance. If no
learning rate schedule defined, optimizer_config.learning_rate is
returned.
tf.keras.optimizers.schedules.LearningRateSchedule instance. If
learning rate type is consant, lr_config.learning_rate is returned.
"""
# TODO(arashwan): Explore if we want to only allow explicit const lr sched.
if not self._lr_config:
lr = self._optimizer_config.learning_rate
if self._lr_type == 'constant':
lr = self._lr_config.learning_rate
else:
lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
......
......@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}
}
optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
......@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_missing_types(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
params = {
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
def test_stepwise_lr_schedule(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'exponential',
......@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'polynomial',
......@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'cosine',
......@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
},
'warmup': {
'type': 'linear',
......@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment