Commit 5a2cf36f authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Merge remote-tracking branch 'upstream/master' into newavarecords

parents 258ddfc3 a829e648
...@@ -10,11 +10,13 @@ can take full advantage of TensorFlow for their research and product development ...@@ -10,11 +10,13 @@ can take full advantage of TensorFlow for their research and product development
| [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read | | [official](official) | • A collection of example implementations for SOTA models using the latest TensorFlow 2's high-level APIs<br />• Officially maintained, supported, and kept up to date with the latest TensorFlow 2 APIs by TensorFlow<br />• Reasonably optimized for fast performance while still being easy to read |
| [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers | | [research](research) | • A collection of research model implementations in TensorFlow 1 or 2 by researchers<br />• Maintained and supported by researchers |
| [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 | | [community](community) | • A curated list of the GitHub repositories with machine learning models and implementations powered by TensorFlow 2 |
| [orbit](orbit) | • A flexible and lightweight library that users can easily use or fork when writing customized training loop code in TensorFlow 2.x. It seamlessly integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU). |
## [Announcements](https://github.com/tensorflow/models/wiki/Announcements) ## [Announcements](https://github.com/tensorflow/models/wiki/Announcements)
| Date | News | | Date | News |
|------|------| |------|------|
| July 10, 2020 | TensorFlow 2 meets the [Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) ([Blog](https://blog.tensorflow.org/2020/07/tensorflow-2-meets-object-detection-api.html)) |
| June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) | | June 30, 2020 | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://github.com/tensorflow/models/tree/master/official/vision/detection#train-a-spinenet-49-based-mask-r-cnn) released ([Tweet](https://twitter.com/GoogleAI/status/1278016712978264064)) |
| June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) | | June 17, 2020 | [Context R-CNN: Long Term Temporal Context for Per-Camera Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection#june-17th-2020) released ([Tweet](https://twitter.com/GoogleAI/status/1276571419422253057)) |
| May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released | | May 21, 2020 | [Unifying Deep Local and Global Features for Image Search (DELG)](https://github.com/tensorflow/models/tree/master/research/delf#delg) code released |
...@@ -23,12 +25,6 @@ can take full advantage of TensorFlow for their research and product development ...@@ -23,12 +25,6 @@ can take full advantage of TensorFlow for their research and product development
| May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 | | May 1, 2020 | [DELF: DEep Local Features](https://github.com/tensorflow/models/tree/master/research/delf) updated to support TensorFlow 2.1 |
| March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) | | March 31, 2020 | [Introducing the Model Garden for TensorFlow 2](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html) ([Tweet](https://twitter.com/TensorFlow/status/1245029834633297921)) |
## [Milestones](https://github.com/tensorflow/models/milestones)
| Date | Milestone |
|------|-----------|
| July 8, 2020 | [![GitHub milestone](https://img.shields.io/github/milestones/progress/tensorflow/models/1)](https://github.com/tensorflow/models/milestone/1) |
## Contributions ## Contributions
[![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation) [![help wanted:paper implementation](https://img.shields.io/github/issues/tensorflow/models/help%20wanted%3Apaper%20implementation)](https://github.com/tensorflow/models/labels/help%20wanted%3Apaper%20implementation)
......
...@@ -20,6 +20,14 @@ This repository provides a curated list of the GitHub repositories with machine ...@@ -20,6 +20,14 @@ This repository provides a curated list of the GitHub repositories with machine
| [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) | | [ResNet 50](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) | | [ResNet 50v1.5](https://github.com/IntelAI/models/tree/master/benchmarks/image_recognition/tensorflow/resnet50v1_5) | [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Object Detection
| Model | Paper | Features | Maintainer |
|-------|-------|----------|------------|
| [R-FCN](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/rfcn) | [R-FCN: Object Detection<br/>via Region-based Fully Convolutional Networks](https://arxiv.org/pdf/1605.06409) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-MobileNet](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-mobilenet) | [MobileNets: Efficient Convolutional Neural Networks<br/>for Mobile Vision Applications](https://arxiv.org/pdf/1704.04861) | • Int8 Inference<br/>• FP32 Inference | [Intel](https://github.com/IntelAI) |
| [SSD-ResNet34](https://github.com/IntelAI/models/tree/master/benchmarks/object_detection/tensorflow/ssd-resnet34) | [SSD: Single Shot MultiBox Detector](https://arxiv.org/pdf/1512.02325) | • Int8 Inference<br/>• FP32 Inference<br/>• FP32 Training | [Intel](https://github.com/IntelAI) |
### Segmentation ### Segmentation
| Model | Paper | Features | Maintainer | | Model | Paper | Features | Maintainer |
......
...@@ -17,12 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build. ...@@ -17,12 +17,9 @@ with the same or improved speed and performance with each new TensorFlow build.
The team is actively developing new models. The team is actively developing new models.
In the near future, we will add: In the near future, we will add:
* State-of-the-art language understanding models: * State-of-the-art language understanding models.
More members in Transformer family * State-of-the-art image classification models.
* State-of-the-art image classification models: * State-of-the-art objection detection and instance segmentation models.
EfficientNet, MnasNet, and variants
* State-of-the-art objection detection and instance segmentation models:
RetinaNet, Mask R-CNN, SpineNet, and variants
## Table of Contents ## Table of Contents
......
...@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -144,6 +144,39 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark(summary_path=summary_path, self._run_and_report_benchmark(summary_path=summary_path,
report_accuracy=True) report_accuracy=True)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 2x2 TPU for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 2x2 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.train_batch_size = 128
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_2x2_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden') @owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self): def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 4x4 TPU for 10000 steps.""" """Test bert pretraining with 4x4 TPU for 10000 steps."""
...@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -159,6 +192,22 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._run_and_report_benchmark( self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False) summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden')
def benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir(self):
"""Test bert pretraining with 4x4 TPU with MLIR for 10000 steps."""
self._setup()
self._specify_common_flags()
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 2
FLAGS.model_dir = self._get_model_dir(
'benchmark_perf_4x4_tpu_bf16_seq128_10k_steps_mlir')
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
tf.config.experimental.enable_mlir_bridge()
# Disable accuracy check.
self._run_and_report_benchmark(
summary_path=summary_path, report_accuracy=False)
@owner_utils.Owner('tf-model-garden') @owner_utils.Owner('tf-model-garden')
def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self): def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
"""Test bert pretraining with 8x8 TPU for 10000 steps.""" """Test bert pretraining with 8x8 TPU for 10000 steps."""
......
...@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS ...@@ -38,13 +38,18 @@ FLAGS = flags.FLAGS
class CtlBenchmark(PerfZeroBenchmark): class CtlBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing.""" """Base benchmark class with methods to simplify testing."""
def __init__(self, output_dir=None, default_flags=None, flag_methods=None): def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
**kwargs):
self.default_flags = default_flags or {} self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {} self.flag_methods = flag_methods or {}
super(CtlBenchmark, self).__init__( super(CtlBenchmark, self).__init__(
output_dir=output_dir, output_dir=output_dir,
default_flags=self.default_flags, default_flags=self.default_flags,
flag_methods=self.flag_methods) flag_methods=self.flag_methods,
**kwargs)
def _report_benchmark(self, def _report_benchmark(self,
stats, stats,
...@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark): ...@@ -190,13 +195,14 @@ class Resnet50CtlAccuracy(CtlBenchmark):
class Resnet50CtlBenchmarkBase(CtlBenchmark): class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks.""" """Resnet50 benchmarks."""
def __init__(self, output_dir=None, default_flags=None): def __init__(self, output_dir=None, default_flags=None, **kwargs):
flag_methods = [common.define_keras_flags] flag_methods = [common.define_keras_flags]
super(Resnet50CtlBenchmarkBase, self).__init__( super(Resnet50CtlBenchmarkBase, self).__init__(
output_dir=output_dir, output_dir=output_dir,
flag_methods=flag_methods, flag_methods=flag_methods,
default_flags=default_flags) default_flags=default_flags,
**kwargs)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self): def _run_and_report_benchmark(self):
...@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -381,12 +387,24 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.single_l2_loss_op = True FLAGS.single_l2_loss_op = True
FLAGS.use_tf_function = True FLAGS.use_tf_function = True
FLAGS.enable_checkpoint_and_export = False FLAGS.enable_checkpoint_and_export = False
FLAGS.data_dir = 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord'
def benchmark_2x2_tpu_bf16(self): def benchmark_2x2_tpu_bf16(self):
self._setup() self._setup()
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 1024 FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_bf16_mlir(self):
self._setup()
self._set_df_common()
FLAGS.batch_size = 1024
FLAGS.dtype = 'bf16'
tf.config.experimental.enable_mlir_bridge()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16_mlir')
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_4x4_tpu_bf16(self): def benchmark_4x4_tpu_bf16(self):
...@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -394,6 +412,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
self._run_and_report_benchmark() self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler') @owner_utils.Owner('tf-graph-compiler')
...@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -403,6 +422,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
self._set_df_common() self._set_df_common()
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16_mlir')
tf.config.experimental.enable_mlir_bridge() tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase): ...@@ -426,11 +446,11 @@ class Resnet50CtlBenchmarkSynth(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['use_synthetic_data'] = True def_flags['use_synthetic_data'] = True
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20 def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkSynth, self).__init__( super(Resnet50CtlBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=def_flags, **kwargs)
class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase): class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
...@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase): ...@@ -441,11 +461,11 @@ class Resnet50CtlBenchmarkReal(Resnet50CtlBenchmarkBase):
def_flags['skip_eval'] = True def_flags['skip_eval'] = True
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet') def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 110 def_flags['train_steps'] = 110
def_flags['steps_per_loop'] = 20 def_flags['steps_per_loop'] = 10
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50CtlBenchmarkReal, self).__init__( super(Resnet50CtlBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags) output_dir=output_dir, default_flags=def_flags, **kwargs)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -271,6 +271,44 @@ class RetinanetBenchmarkReal(RetinanetAccuracy): ...@@ -271,6 +271,44 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
FLAGS.strategy_type = 'tpu' FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0) self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco')
FLAGS.strategy_type = 'tpu'
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_2x2_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 64
params['train']['total_steps'] = 1875 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_2x2_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver
def benchmark_4x4_tpu_coco_mlir(self):
"""Run RetinaNet model accuracy test with 4 TPUs."""
self._setup()
params = self._params()
params['train']['batch_size'] = 256
params['train']['total_steps'] = 469 # One epoch.
params['train']['iterations_per_loop'] = 500
FLAGS.model_dir = self._get_model_dir('real_benchmark_4x4_tpu_coco_mlir')
FLAGS.strategy_type = 'tpu'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(params, do_eval=False, warmup=0)
@flagsaver.flagsaver @flagsaver.flagsaver
def benchmark_2x2_tpu_spinenet_coco(self): def benchmark_2x2_tpu_spinenet_coco(self):
"""Run SpineNet with RetinaNet model accuracy test with 4 TPUs.""" """Run SpineNet with RetinaNet model accuracy test with 4 TPUs."""
......
...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc ...@@ -29,6 +29,8 @@ from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main from official.nlp.transformer import transformer_main as transformer_main
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
TPU_DATA_DIR = 'gs://mlcompass-data/transformer'
GPU_DATA_DIR = os.getenv('TMPDIR')
TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official' TRANSFORMER_EN2DE_DATA_DIR_NAME = 'wmt32k-en2de-official'
EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014' EN2DE_2014_BLEU_DATA_DIR_NAME = 'newstest2014'
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -40,37 +42,54 @@ class TransformerBenchmark(PerfZeroBenchmark):
Code under test for the Transformer Keras models report the same data and Code under test for the Transformer Keras models report the same data and
require the same FLAG setup. require the same FLAG setup.
""" """
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None, def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None, tpu=None): flag_methods=None, tpu=None):
self._set_data_files(root_data_dir=root_data_dir)
if default_flags is None:
default_flags = {}
default_flags['data_dir'] = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file
super(TransformerBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
def _set_data_files(self, root_data_dir=None, tpu_run=False):
"""Sets train_data_dir, vocab_file, bleu_source and bleu_ref."""
# Use remote storage for TPU, remote storage for GPU if defined, else
# use environment provided root_data_dir.
if tpu_run:
root_data_dir = TPU_DATA_DIR
elif GPU_DATA_DIR is not None:
root_data_dir = GPU_DATA_DIR
root_data_dir = root_data_dir if root_data_dir else '' root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir, self.train_data_dir = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME) TRANSFORMER_EN2DE_DATA_DIR_NAME)
self.vocab_file = os.path.join(root_data_dir, self.vocab_file = os.path.join(root_data_dir,
TRANSFORMER_EN2DE_DATA_DIR_NAME, TRANSFORMER_EN2DE_DATA_DIR_NAME,
'vocab.ende.32768') 'vocab.ende.32768')
self.bleu_source = os.path.join(root_data_dir, self.bleu_source = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.en') 'newstest2014.en')
self.bleu_ref = os.path.join(root_data_dir, self.bleu_ref = os.path.join(root_data_dir,
EN2DE_2014_BLEU_DATA_DIR_NAME, EN2DE_2014_BLEU_DATA_DIR_NAME,
'newstest2014.de') 'newstest2014.de')
if default_flags is None: def _set_data_file_flags(self):
default_flags = {} """Sets the FLAGS for the data files."""
default_flags['data_dir'] = self.train_data_dir FLAGS.data_dir = self.train_data_dir
default_flags['vocab_file'] = self.vocab_file FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
super(TransformerBenchmark, self).__init__( FLAGS['bleu_source'].value = self.bleu_source
output_dir=output_dir, FLAGS['bleu_ref'].value = self.bleu_ref
default_flags=default_flags,
flag_methods=flag_methods,
tpu=tpu)
@benchmark_wrappers.enable_runtime_flags @benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, def _run_and_report_benchmark(self,
...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -164,12 +183,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 2048 FLAGS.batch_size = 2048
FLAGS.train_steps = 1000 FLAGS.train_steps = 1000
...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -189,12 +204,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
not converge to the 27.3 BLEU (uncased) SOTA. not converge to the 27.3 BLEU (uncased) SOTA.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 1 FLAGS.num_gpus = 1
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096 FLAGS.batch_size = 4096
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -215,12 +226,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark): ...@@ -237,12 +244,8 @@ class TransformerBaseKerasAccuracy(TransformerBenchmark):
Should converge to 27.3 BLEU (uncased). This has not been confirmed yet. Should converge to 27.3 BLEU (uncased). This has not been confirmed yet.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.batch_size = 4096*8 FLAGS.batch_size = 4096*8
FLAGS.train_steps = 100000 FLAGS.train_steps = 100000
...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -284,12 +287,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Iterations are not epochs, an iteration is a number of steps between evals. Iterations are not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -306,12 +305,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -337,13 +332,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
not epochs, an iteration is a number of steps between evals. not epochs, an iteration is a number of steps between evals.
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -360,14 +351,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.train_steps = 20000 * 12 FLAGS.train_steps = 20000 * 12
...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -384,13 +371,9 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark): ...@@ -409,14 +392,10 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
Should converge to 28.4 BLEU (uncased). This has not be verified yet." Should converge to 28.4 BLEU (uncased). This has not be verified yet."
""" """
self._setup() self._setup()
self._set_data_file_flags()
FLAGS.num_gpus = 8 FLAGS.num_gpus = 8
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.data_dir = self.train_data_dir
FLAGS.vocab_file = self.vocab_file
# Sets values directly to avoid validation check.
FLAGS['bleu_source'].value = self.bleu_source
FLAGS['bleu_ref'].value = self.bleu_ref
FLAGS.param_set = 'big' FLAGS.param_set = 'big'
FLAGS.batch_size = 3072*8 FLAGS.batch_size = 3072*8
FLAGS.static_batch = True FLAGS.static_batch = True
...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -687,22 +666,41 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
root_data_dir=root_data_dir, batch_per_gpu=3072, root_data_dir=root_data_dir, batch_per_gpu=3072,
tpu=tpu) tpu=tpu)
def benchmark_2x2_tpu(self): def _set_df_common(self):
"""Port of former snaggletooth transformer_big model on 2x2.""" self._set_data_files(tpu_run=True)
self._setup() FLAGS.data_dir = self.train_data_dir
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu') FLAGS.vocab_file = self.vocab_file
FLAGS.distribution_strategy = 'tpu'
FLAGS.padded_decode = True
FLAGS.train_steps = 300 FLAGS.train_steps = 300
FLAGS.log_steps = 150 FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150 FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True FLAGS.static_batch = True
FLAGS.use_ctl = True FLAGS.use_ctl = True
FLAGS.batch_size = 6144 FLAGS.enable_checkpointing = False
FLAGS.max_length = 64 FLAGS.max_length = 64
FLAGS.decode_batch_size = 32 FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97 FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False def benchmark_2x2_tpu(self):
"""Port of former snaggletooth transformer_big model on 2x2."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
FLAGS.batch_size = 6144
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
"""Run transformer_big model on 2x2 with the MLIR Bridge enabled."""
self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mlir')
FLAGS.batch_size = 6144
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -711,19 +709,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu(self): def benchmark_4x4_tpu(self):
"""Port of former GCP transformer_big model on 4x4.""" """Port of former GCP transformer_big model on 4x4."""
self._setup() self._setup()
self._set_df_common()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
self._run_and_report_benchmark( self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -733,19 +721,9 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
def benchmark_4x4_tpu_mlir(self): def benchmark_4x4_tpu_mlir(self):
"""Run transformer_big model on 4x4 with the MLIR Bridge enabled.""" """Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
self._setup() self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu') self._set_df_common()
FLAGS.train_steps = 300 FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_mlir')
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576 FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
tf.config.experimental.enable_mlir_bridge() tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark( self._run_and_report_benchmark(
......
...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark): ...@@ -93,8 +93,11 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
"""Runs and reports the benchmark given the provided configuration.""" """Runs and reports the benchmark given the provided configuration."""
params = unet_training_lib.extract_params(FLAGS) params = unet_training_lib.extract_params(FLAGS)
strategy = unet_training_lib.create_distribution_strategy(params) strategy = unet_training_lib.create_distribution_strategy(params)
if params.use_bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16') input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy) tf.keras.mixed_precision.experimental.set_policy(policy)
stats = {} stats = {}
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"colab": {}, "colab": {},
...@@ -104,7 +104,7 @@ ...@@ -104,7 +104,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -128,7 +128,7 @@ ...@@ -128,7 +128,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -185,7 +185,7 @@ ...@@ -185,7 +185,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -204,12 +204,12 @@ ...@@ -204,12 +204,12 @@
"id": "9uFskufsR2LT" "id": "9uFskufsR2LT"
}, },
"source": [ "source": [
"You can get a pre-trained BERT encoder from TensorFlow Hub here:" "You can get a pre-trained BERT encoder from [TensorFlow Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2):"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -252,7 +252,7 @@ ...@@ -252,7 +252,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -267,7 +267,7 @@ ...@@ -267,7 +267,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -290,7 +290,7 @@ ...@@ -290,7 +290,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -313,7 +313,7 @@ ...@@ -313,7 +313,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -336,7 +336,7 @@ ...@@ -336,7 +336,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -376,7 +376,7 @@ ...@@ -376,7 +376,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -404,7 +404,7 @@ ...@@ -404,7 +404,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -446,7 +446,7 @@ ...@@ -446,7 +446,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -469,7 +469,7 @@ ...@@ -469,7 +469,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -490,7 +490,7 @@ ...@@ -490,7 +490,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -514,7 +514,7 @@ ...@@ -514,7 +514,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -562,7 +562,7 @@ ...@@ -562,7 +562,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -587,7 +587,7 @@ ...@@ -587,7 +587,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -617,7 +617,7 @@ ...@@ -617,7 +617,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -661,7 +661,7 @@ ...@@ -661,7 +661,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -691,7 +691,7 @@ ...@@ -691,7 +691,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -737,7 +737,7 @@ ...@@ -737,7 +737,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -769,7 +769,7 @@ ...@@ -769,7 +769,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -793,7 +793,7 @@ ...@@ -793,7 +793,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -816,7 +816,7 @@ ...@@ -816,7 +816,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -845,7 +845,7 @@ ...@@ -845,7 +845,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -870,7 +870,7 @@ ...@@ -870,7 +870,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -908,7 +908,7 @@ ...@@ -908,7 +908,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -943,7 +943,7 @@ ...@@ -943,7 +943,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -986,7 +986,7 @@ ...@@ -986,7 +986,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1023,7 +1023,7 @@ ...@@ -1023,7 +1023,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1055,7 +1055,7 @@ ...@@ -1055,7 +1055,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1071,7 +1071,7 @@ ...@@ -1071,7 +1071,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1096,7 +1096,7 @@ ...@@ -1096,7 +1096,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1110,7 +1110,7 @@ ...@@ -1110,7 +1110,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1176,7 +1176,7 @@ ...@@ -1176,7 +1176,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1201,7 +1201,7 @@ ...@@ -1201,7 +1201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1240,7 +1240,7 @@ ...@@ -1240,7 +1240,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1273,7 +1273,7 @@ ...@@ -1273,7 +1273,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1306,7 +1306,7 @@ ...@@ -1306,7 +1306,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1351,7 +1351,7 @@ ...@@ -1351,7 +1351,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1379,7 +1379,7 @@ ...@@ -1379,7 +1379,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1406,17 +1406,44 @@ ...@@ -1406,17 +1406,44 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
"id": "lo6479At4sP1" "id": "GDWrHm0BGpbX"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Note: 350MB download.\n", "# Note: 350MB download.\n",
"import tensorflow_hub as hub\n", "import tensorflow_hub as hub"
"hub_encoder = hub.KerasLayer(hub_url_bert, trainable=True)\n", ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Y29meH0qGq_5"
},
"outputs": [],
"source": [
"hub_model_name = \"bert_en_uncased_L-12_H-768_A-12\" #@param [\"bert_en_uncased_L-24_H-1024_A-16\", \"bert_en_wwm_cased_L-24_H-1024_A-16\", \"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_wwm_uncased_L-24_H-1024_A-16\", \"bert_en_cased_L-24_H-1024_A-16\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_zh_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}\",\n",
" trainable=True)\n",
"\n", "\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")" "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
] ]
...@@ -1433,7 +1460,7 @@ ...@@ -1433,7 +1460,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1466,7 +1493,7 @@ ...@@ -1466,7 +1493,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1491,7 +1518,7 @@ ...@@ -1491,7 +1518,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1504,7 +1531,7 @@ ...@@ -1504,7 +1531,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1545,7 +1572,7 @@ ...@@ -1545,7 +1572,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1569,7 +1596,7 @@ ...@@ -1569,7 +1596,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1592,7 +1619,7 @@ ...@@ -1592,7 +1619,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1617,7 +1644,7 @@ ...@@ -1617,7 +1644,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1643,7 +1670,7 @@ ...@@ -1643,7 +1670,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1661,7 +1688,7 @@ ...@@ -1661,7 +1688,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1688,7 +1715,7 @@ ...@@ -1688,7 +1715,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1714,7 +1741,7 @@ ...@@ -1714,7 +1741,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1733,7 +1760,7 @@ ...@@ -1733,7 +1760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1761,7 +1788,7 @@ ...@@ -1761,7 +1788,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
...@@ -1795,7 +1822,7 @@ ...@@ -1795,7 +1822,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
......
This diff is collapsed.
This diff is collapsed.
...@@ -59,7 +59,7 @@ class Task(tf.Module): ...@@ -59,7 +59,7 @@ class Task(tf.Module):
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
"""A callback function used as CheckpointManager's init_fn. """A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint found for the model. This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir. checkpoint, saved under a directory other than the model_dir.
...@@ -71,7 +71,7 @@ class Task(tf.Module): ...@@ -71,7 +71,7 @@ class Task(tf.Module):
@abc.abstractmethod @abc.abstractmethod
def build_model(self) -> tf.keras.Model: def build_model(self) -> tf.keras.Model:
"""Creates the model architecture. """Creates model architecture.
Returns: Returns:
A model instance. A model instance.
...@@ -135,7 +135,7 @@ class Task(tf.Module): ...@@ -135,7 +135,7 @@ class Task(tf.Module):
Args: Args:
labels: optional label tensors. labels: optional label tensors.
model_outputs: a nested structure of output tensors. model_outputs: a nested structure of output tensors.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model. aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns: Returns:
The total loss tensor. The total loss tensor.
...@@ -232,7 +232,7 @@ class Task(tf.Module): ...@@ -232,7 +232,7 @@ class Task(tf.Module):
return logs return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validation step.
With distribution strategies, this method runs on devices. With distribution strategies, this method runs on devices.
......
...@@ -171,6 +171,9 @@ class InputReader: ...@@ -171,6 +171,9 @@ class InputReader:
as_supervised=self._tfds_as_supervised, as_supervised=self._tfds_as_supervised,
decoders=decoders, decoders=decoders,
read_config=read_config) read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset return dataset
@property @property
......
...@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict): ...@@ -126,10 +126,10 @@ class Config(params_dict.ParamsDict):
subconfig_type = Config subconfig_type = Config
if k in cls.__annotations__: if k in cls.__annotations__:
# Directly Config subtype. # Directly Config subtype.
type_annotation = cls.__annotations__[k] type_annotation = cls.__annotations__[k] # pytype: disable=invalid-annotation
if (isinstance(type_annotation, type) and if (isinstance(type_annotation, type) and
issubclass(type_annotation, Config)): issubclass(type_annotation, Config)):
subconfig_type = cls.__annotations__[k] subconfig_type = cls.__annotations__[k] # pytype: disable=invalid-annotation
else: else:
# Check if the field is a sequence of subtypes. # Check if the field is a sequence of subtypes.
field_type = getattr(type_annotation, '__origin__', type(None)) field_type = getattr(type_annotation, '__origin__', type(None))
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common configuration settings.""" """Common configuration settings."""
from typing import Optional, Union from typing import Optional, Union
import dataclasses import dataclasses
...@@ -111,6 +112,8 @@ class RuntimeConfig(base_config.Config): ...@@ -111,6 +112,8 @@ class RuntimeConfig(base_config.Config):
run_eagerly: Whether or not to run the experiment eagerly. run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance. persistent mode for CuDNN batch norm kernel for improved GPU performance.
allow_tpu_summary: Whether to allow summary happen inside the XLA program
runs on TPU through automatic outside compilation.
""" """
distribution_strategy: str = "mirrored" distribution_strategy: str = "mirrored"
enable_xla: bool = False enable_xla: bool = False
...@@ -123,8 +126,8 @@ class RuntimeConfig(base_config.Config): ...@@ -123,8 +126,8 @@ class RuntimeConfig(base_config.Config):
task_index: int = -1 task_index: int = -1
all_reduce_alg: Optional[str] = None all_reduce_alg: Optional[str] = None
num_packs: int = 1 num_packs: int = 1
loss_scale: Optional[Union[str, float]] = None
mixed_precision_dtype: Optional[str] = None mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False batchnorm_spatial_persistent: bool = False
...@@ -172,23 +175,32 @@ class TrainerConfig(base_config.Config): ...@@ -172,23 +175,32 @@ class TrainerConfig(base_config.Config):
eval_tf_function: whether or not to use tf_function for eval. eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop. steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary. summary_interval: number of steps between each summary.
checkpoint_intervals: number of steps between checkpoints. checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep. max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinetely. checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
""" """
optimizer_config: OptimizationConfig = OptimizationConfig() optimizer_config: OptimizationConfig = OptimizationConfig()
train_steps: int = 0 # Orbit settings.
validation_steps: Optional[int] = None train_tf_while_loop: bool = True
validation_interval: int = 100 train_tf_function: bool = True
eval_tf_function: bool = True
allow_tpu_summary: bool = False
# Trainer intervals.
steps_per_loop: int = 1000 steps_per_loop: int = 1000
summary_interval: int = 1000 summary_interval: int = 1000
checkpoint_interval: int = 1000 checkpoint_interval: int = 1000
# Checkpoint manager.
max_to_keep: int = 5 max_to_keep: int = 5
continuous_eval_timeout: Optional[int] = None continuous_eval_timeout: Optional[int] = None
train_tf_while_loop: bool = True # Train/Eval routines.
train_tf_function: bool = True train_steps: int = 0
eval_tf_function: bool = True validation_steps: Optional[int] = None
validation_interval: int = 1000
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -20,6 +20,20 @@ import dataclasses ...@@ -20,6 +20,20 @@ import dataclasses
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ConstantLrConfig(base_config.Config):
"""Configuration for constant learning rate.
This class is a containers for the constant learning rate decay configs.
Attributes:
name: The name of the learning rate schedule. Defaults to Constant.
learning_rate: A float. The learning rate. Defaults to 0.1.
"""
name: str = 'Constant'
learning_rate: float = 0.1
@dataclasses.dataclass @dataclasses.dataclass
class StepwiseLrConfig(base_config.Config): class StepwiseLrConfig(base_config.Config):
"""Configuration for stepwise learning rate decay. """Configuration for stepwise learning rate decay.
......
...@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig): ...@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig):
Attributes: Attributes:
type: 'str', type of lr schedule to be used, on the of fields below. type: 'str', type of lr schedule to be used, on the of fields below.
constant: constant learning rate config.
stepwise: stepwise learning rate config. stepwise: stepwise learning rate config.
exponential: exponential learning rate config. exponential: exponential learning rate config.
polynomial: polynomial learning rate config. polynomial: polynomial learning rate config.
cosine: cosine learning rate config. cosine: cosine learning rate config.
""" """
type: Optional[str] = None type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig() stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig() exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig() polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
......
...@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config): ...@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for SGD optimizer.
decay: decay rate for SGD optimizer. decay: decay rate for SGD optimizer.
nesterov: nesterov for SGD optimizer. nesterov: nesterov for SGD optimizer.
momentum: momentum for SGD optimizer. momentum: momentum for SGD optimizer.
""" """
name: str = "SGD" name: str = "SGD"
learning_rate: float = 0.01
decay: float = 0.0 decay: float = 0.0
nesterov: bool = False nesterov: bool = False
momentum: float = 0.0 momentum: float = 0.0
...@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config): ...@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for RMSprop optimizer.
rho: discounting factor for RMSprop optimizer. rho: discounting factor for RMSprop optimizer.
momentum: momentum for RMSprop optimizer. momentum: momentum for RMSprop optimizer.
epsilon: epsilon value for RMSprop optimizer, help with numerical stability. epsilon: epsilon value for RMSprop optimizer, help with numerical stability.
centered: Whether to normalize gradients or not. centered: Whether to normalize gradients or not.
""" """
name: str = "RMSprop" name: str = "RMSprop"
learning_rate: float = 0.001
rho: float = 0.9 rho: float = 0.9
momentum: float = 0.0 momentum: float = 0.0
epsilon: float = 1e-7 epsilon: float = 1e-7
...@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config): ...@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2st order moments. beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in Adam optimizer. epsilon: epsilon value used for numerical stability in Adam optimizer.
...@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config): ...@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config):
the paper "On the Convergence of Adam and beyond". the paper "On the Convergence of Adam and beyond".
""" """
name: str = "Adam" name: str = "Adam"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-07 epsilon: float = 1e-07
...@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config): ...@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for the optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2st order moments. beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in the optimizer. epsilon: epsilon value used for numerical stability in the optimizer.
...@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config): ...@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config):
include in weight decay. include in weight decay.
""" """
name: str = "AdamWeightDecay" name: str = "AdamWeightDecay"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-07 epsilon: float = 1e-07
...@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config): ...@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config):
Attributes: Attributes:
name: name of the optimizer. name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments. beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2st order moments. beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in LAMB optimizer. epsilon: epsilon value used for numerical stability in LAMB optimizer.
...@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config): ...@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config):
be excluded. be excluded.
""" """
name: str = "LAMB" name: str = "LAMB"
learning_rate: float = 0.001
beta_1: float = 0.9 beta_1: float = 0.9
beta_2: float = 0.999 beta_2: float = 0.999
epsilon: float = 1e-6 epsilon: float = 1e-6
......
...@@ -60,7 +60,7 @@ class OptimizerFactory(object): ...@@ -60,7 +60,7 @@ class OptimizerFactory(object):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -88,12 +88,15 @@ class OptimizerFactory(object): ...@@ -88,12 +88,15 @@ class OptimizerFactory(object):
self._optimizer_config = config.optimizer.get() self._optimizer_config = config.optimizer.get()
self._optimizer_type = config.optimizer.type self._optimizer_type = config.optimizer.type
if self._optimizer_config is None: if self._optimizer_type is None:
raise ValueError('Optimizer type must be specified') raise ValueError('Optimizer type must be specified')
self._lr_config = config.learning_rate.get() self._lr_config = config.learning_rate.get()
self._lr_type = config.learning_rate.type self._lr_type = config.learning_rate.type
if self._lr_type is None:
raise ValueError('Learning rate type must be specified')
self._warmup_config = config.warmup.get() self._warmup_config = config.warmup.get()
self._warmup_type = config.warmup.type self._warmup_type = config.warmup.type
...@@ -101,18 +104,15 @@ class OptimizerFactory(object): ...@@ -101,18 +104,15 @@ class OptimizerFactory(object):
"""Build learning rate. """Build learning rate.
Builds learning rate from config. Learning rate schedule is built according Builds learning rate from config. Learning rate schedule is built according
to the learning rate config. If there is no learning rate config, optimizer to the learning rate config. If learning rate type is consant,
learning rate is returned. lr_config.learning_rate is returned.
Returns: Returns:
tf.keras.optimizers.schedules.LearningRateSchedule instance. If no tf.keras.optimizers.schedules.LearningRateSchedule instance. If
learning rate schedule defined, optimizer_config.learning_rate is learning rate type is consant, lr_config.learning_rate is returned.
returned.
""" """
if self._lr_type == 'constant':
# TODO(arashwan): Explore if we want to only allow explicit const lr sched. lr = self._lr_config.learning_rate
if not self._lr_config:
lr = self._optimizer_config.learning_rate
else: else:
lr = LR_CLS[self._lr_type](**self._lr_config.as_dict()) lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
......
...@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': optimizer_type 'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
} }
} }
optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type] optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config() expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
opt_config = optimization_config.OptimizationConfig(params) opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config) opt_factory = optimizer_factory.OptimizerFactory(opt_config)
...@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(optimizer, optimizer_cls) self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config()) self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_missing_types(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
params = {
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
def test_stepwise_lr_schedule(self): def test_stepwise_lr_schedule(self):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
...@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'exponential', 'type': 'exponential',
...@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'polynomial', 'type': 'polynomial',
...@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'cosine', 'type': 'cosine',
...@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}, },
'warmup': { 'warmup': {
'type': 'linear', 'type': 'linear',
...@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = { params = {
'optimizer': { 'optimizer': {
'type': 'sgd', 'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9} 'sgd': {'momentum': 0.9}
}, },
'learning_rate': { 'learning_rate': {
'type': 'stepwise', 'type': 'stepwise',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment