Unverified commit 9b049266 authored by Ayushman Kumar, committed by GitHub

Merge pull request #3 from tensorflow/master

Updated
parents 63af6ba5 c5ad244e
......@@ -82,15 +82,27 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
with tf.io.gfile.GFile(predictions_file, 'r') as reader:
return json.load(reader)
def _get_distribution_strategy(self, use_ds=True):
"""Gets the distribution strategy."""
if self.tpu:
def _get_distribution_strategy(self, ds_type='mirrored'):
"""Gets the distribution strategy.
Args:
ds_type: String, the distribution strategy type to be used. Can be
'mirrored', 'multi_worker_mirrored', 'tpu', or 'off'.
Returns:
A `tf.distribute.DistributionStrategy` object.
"""
if self.tpu or ds_type == 'tpu':
return distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
else:
return distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus)
elif ds_type == 'multi_worker_mirrored':
# Configures cluster spec for multi-worker distribution strategy.
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
return distribution_utils.get_distribution_strategy(
distribution_strategy=ds_type,
num_gpus=self.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg)
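# A minimal illustrative sketch, assuming official.utils.misc.distribution_utils
# is importable: the ds_type values above map directly onto
# get_distribution_strategy ('off' yields the default, strategy-free wrapper).
# For 'multi_worker_mirrored' the cluster spec must be configured first via
# distribution_utils.configure_cluster(worker_hosts, task_index).
from official.utils.misc import distribution_utils

def example_strategy(ds_type='mirrored', num_gpus=1, tpu_address=''):
  """Hypothetical standalone mirror of _get_distribution_strategy above."""
  if ds_type == 'tpu':
    return distribution_utils.get_distribution_strategy(
        distribution_strategy='tpu', tpu_address=tpu_address)
  return distribution_utils.get_distribution_strategy(
      distribution_strategy=ds_type, num_gpus=num_gpus)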
def _init_gpu_and_data_threads(self):
"""Set env variables before any TF calls."""
......@@ -102,12 +114,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
datasets_num_private_threads=FLAGS.datasets_num_private_threads)
@flagsaver.flagsaver
def _train_squad(self, use_ds=True, run_eagerly=False):
"""Runs BERT SQuAD training."""
def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
"""Runs BERT SQuAD training. Uses mirrored strategy by default."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds)
strategy = self._get_distribution_strategy(ds_type)
run_squad.train_squad(
strategy=strategy,
......@@ -116,12 +128,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
custom_callbacks=[self.timer_callback])
@flagsaver.flagsaver
def _evaluate_squad(self, use_ds=True):
"""Runs BERT SQuAD evaluation."""
def _evaluate_squad(self, ds_type='mirrored'):
"""Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds)
strategy = self._get_distribution_strategy(ds_type)
run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)
......@@ -157,15 +169,15 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
run_eagerly=False,
ds_type='mirrored'):
"""Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4:
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
else:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
......@@ -217,7 +229,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
FLAGS.train_batch_size = 4
self._run_and_report_benchmark(use_ds=False)
self._run_and_report_benchmark(ds_type='off')
def benchmark_1_gpu_eager_no_dist_strat(self):
"""Tests BERT SQuAD model performance with 1 GPU with eager execution."""
......@@ -228,7 +240,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
'benchmark_1_gpu_eager_no_dist_strat_squad')
FLAGS.train_batch_size = 4
self._run_and_report_benchmark(use_ds=False, run_eagerly=True)
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
def benchmark_2_gpu(self):
"""Tests BERT SQuAD model performance with 2 GPUs."""
......@@ -420,12 +432,12 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self,
use_ds=True,
run_eagerly=False):
run_eagerly=False,
ds_type='mirrored'):
"""Runs the benchmark and reports various metrics."""
start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
self._evaluate_squad()
self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
self._evaluate_squad(ds_type=ds_type)
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
......@@ -445,7 +457,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
FLAGS.train_batch_size = 4
self._run_and_report_benchmark(use_ds=False, run_eagerly=True)
self._run_and_report_benchmark(ds_type='off', run_eagerly=True)
def benchmark_8_gpu(self):
"""Tests BERT SQuAD model accuracy with 8 GPUs."""
......@@ -518,8 +530,9 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
run_eagerly=False):
"""Runs the benchmark and reports various metrics."""
start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
self._evaluate_squad()
self._train_squad(run_eagerly=run_eagerly,
ds_type='multi_worker_mirrored')
self._evaluate_squad(ds_type='multi_worker_mirrored')
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
......@@ -595,7 +608,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
else:
FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
start_time_sec = time.time()
self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
self._train_squad(run_eagerly=run_eagerly,
ds_type='multi_worker_mirrored')
wall_time_sec = time.time() - start_time_sec
summary = self._read_training_summary_from_file()
......
......@@ -53,6 +53,10 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
super(Resnet56KerasAccuracy, self).__init__(
output_dir=output_dir, flag_methods=flag_methods)
def _setup(self):
super(Resnet56KerasAccuracy, self)._setup()
FLAGS.use_tensor_lr = False
def benchmark_graph_1_gpu(self):
"""Test keras based model with Keras fit and distribution strategies."""
self._setup()
......@@ -439,6 +443,7 @@ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
default_flags['use_synthetic_data'] = True
default_flags['train_steps'] = 110
default_flags['log_steps'] = 10
default_flags['use_tensor_lr'] = False
super(Resnet56KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=default_flags)
......@@ -453,6 +458,7 @@ class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
default_flags['train_steps'] = 110
default_flags['log_steps'] = 10
default_flags['use_tensor_lr'] = False
super(Resnet56KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=default_flags)
......
......@@ -71,7 +71,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.epochs_between_evals = 10
FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
FLAGS.dtype = 'fp32'
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
......@@ -87,7 +86,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True
# Add some thread tunings to improve performance.
FLAGS.datasets_num_private_threads = 14
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
......@@ -104,7 +102,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.fp16_implementation = 'graph_rewrite'
# Add some thread tunings to improve performance.
FLAGS.datasets_num_private_threads = 14
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16(self):
......@@ -120,7 +117,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True
# Thread tuning to improve performance.
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_fp16(self):
......@@ -137,7 +133,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True
# Thread tuning to improve performance.
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
def benchmark_8_gpu_mlperf_like(self):
......@@ -179,7 +174,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.loss_scale = 'dynamic'
# Thread tuning to improve performance.
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark(top_1_min=0.736)
@benchmark_wrappers.enable_runtime_flags
......@@ -241,7 +235,6 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True
# Add some thread tunings to improve performance.
FLAGS.datasets_num_private_threads = 14
FLAGS.use_tensor_lr = True
self._run_and_report_benchmark()
@benchmark_wrappers.enable_runtime_flags
......@@ -472,7 +465,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
FLAGS.dtype = 'fp16'
FLAGS.batch_size = 256
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -550,7 +542,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
'benchmark_graph_xla_1_gpu_fp16_tweaked')
FLAGS.dtype = 'fp16'
FLAGS.batch_size = 256
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -587,7 +578,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.use_tensor_lr = True
FLAGS.datasets_num_private_threads = 14
self._run_and_report_benchmark()
......@@ -627,7 +617,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_tweaked')
FLAGS.batch_size = 128 * 8
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 24
self._run_and_report_benchmark()
......@@ -654,7 +643,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -670,7 +658,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
'benchmark_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.loss_scale = 'dynamic'
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -698,7 +685,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 48
self._run_and_report_benchmark()
......@@ -718,7 +704,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_8_gpu_fp16_tweaked_delay_measure')
FLAGS.batch_size = 256 * 8
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.train_steps = 310
self._run_and_report_benchmark()
......@@ -736,7 +721,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.loss_scale = 'dynamic'
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 48
self._run_and_report_benchmark()
......@@ -799,7 +783,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -815,7 +798,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.model_dir = self._get_model_dir(
'benchmark_graph_xla_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -834,7 +816,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.model_dir = self._get_model_dir(
'benchmark_graph_xla_8_gpu_fp16_tweaked_delay_measure')
FLAGS.batch_size = 256 * 8
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.train_steps = 310
self._run_and_report_benchmark()
......@@ -851,7 +832,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
'benchmark_graph_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.loss_scale = 'dynamic'
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -867,7 +847,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.model_dir = self._get_model_dir(
'benchmark_graph_xla_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.use_tensor_lr = True
FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -963,7 +942,6 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
def_flags['use_trivial_model'] = True
def_flags['skip_eval'] = True
def_flags['report_accuracy_metrics'] = False
def_flags['use_tensor_lr'] = True
def_flags['dtype'] = 'fp16'
def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
def_flags['train_steps'] = 600
......@@ -1097,7 +1075,6 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = eager
FLAGS.enable_xla = False
FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir(
......@@ -1161,7 +1138,6 @@ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
FLAGS.enable_eager = eager
FLAGS.enable_xla = False
FLAGS.distribution_strategy = 'multi_worker_mirrored'
FLAGS.use_tensor_lr = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 32
FLAGS.model_dir = self._get_model_dir(
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app as absl_app
import numpy as np
from absl import flags
import tensorflow as tf
from official.benchmark.models import resnet_cifar_model
......@@ -64,6 +64,46 @@ def learning_rate_schedule(current_epoch,
return learning_rate
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Callback to update learning rate on every batch (not epoch boundaries).
N.B. Only supports Keras optimizers, not TF optimizers.
Attributes:
schedule: a function that takes an epoch index, a batch index, the number of
steps per epoch, and the global batch size (indices are integers, starting
from 0) and returns a new learning rate as output (float).
"""
def __init__(self, schedule, batch_size, steps_per_epoch):
super(LearningRateBatchScheduler, self).__init__()
self.schedule = schedule
self.steps_per_epoch = steps_per_epoch
self.batch_size = batch_size
self.epochs = -1
self.prev_lr = -1
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'learning_rate'):
raise ValueError('Optimizer must have a "learning_rate" attribute.')
self.epochs += 1
def on_batch_begin(self, batch, logs=None):
"""Executes before step begins."""
lr = self.schedule(self.epochs,
batch,
self.steps_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be a float.')
if lr != self.prev_lr:
self.model.optimizer.learning_rate = lr # lr should be a float here
self.prev_lr = lr
tf.compat.v1.logging.debug(
'Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.', self.epochs, batch, lr)
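# A minimal usage sketch, assuming a hypothetical four-argument schedule that
# matches the call made in on_batch_begin above; the model, dataset, and
# numeric values are placeholders.
def example_schedule(epoch, batch, steps_per_epoch, batch_size):
  """Toy schedule: halve the learning rate every epoch."""
  del batch, steps_per_epoch, batch_size  # unused by this toy schedule
  return 0.1 * (0.5 ** epoch)

# lr_callback = LearningRateBatchScheduler(
#     schedule=example_schedule, batch_size=128, steps_per_epoch=390)
# model.fit(train_dataset, epochs=10, callbacks=[lr_callback])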
def run(flags_obj):
"""Run ResNet Cifar-10 training and eval loop using native Keras APIs.
......@@ -151,8 +191,18 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=cifar_preprocessing.parse_record)
steps_per_epoch = (
cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
lr_schedule = 0.1
if flags_obj.use_tensor_lr:
initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
values=[initial_learning_rate] +
list(p[0] * initial_learning_rate for p in LR_SCHEDULE))
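# A worked sketch under the assumption that LR_SCHEDULE is a list of
# (multiplier, start_epoch) pairs; the values below are hypothetical. The
# boundaries/values above expand into a step-indexed piecewise-constant decay.
import tensorflow as tf

example_lr_schedule = [(0.1, 91), (0.01, 136), (0.001, 182)]  # hypothetical
example_steps_per_epoch = 390
example_initial_lr = 0.1
example_decay = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[e * example_steps_per_epoch for _, e in example_lr_schedule],
    values=[example_initial_lr] +
           [m * example_initial_lr for m, _ in example_lr_schedule])
# example_decay(step) yields 0.1 until step 91 * 390, then 0.1 * 0.1, and so on.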
with strategy_scope:
optimizer = common.get_optimizer()
optimizer = common.get_optimizer(lr_schedule)
model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
# TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
......@@ -173,11 +223,16 @@ def run(flags_obj):
if flags_obj.report_accuracy_metrics else None),
run_eagerly=flags_obj.run_eagerly)
steps_per_epoch = (
cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs
callbacks = common.get_callbacks(steps_per_epoch, learning_rate_schedule)
callbacks = common.get_callbacks(steps_per_epoch)
if not flags_obj.use_tensor_lr:
lr_callback = LearningRateBatchScheduler(
schedule=learning_rate_schedule,
batch_size=flags_obj.batch_size,
steps_per_epoch=steps_per_epoch)
callbacks.append(lr_callback)
# If multiple epochs, ignore the train_steps flag.
if train_epochs <= 1 and flags_obj.train_steps:
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to training performance."""
import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale="dynamic"):
"""Configures optimizer object with performance options."""
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=loss_scale))
if use_graph_rewrite:
# Note: the model dtype must be 'float32', which will ensure
# tf.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
return optimizer
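# A minimal usage sketch for the helper above; the optimizer and loss_scale
# values are placeholders.
import tensorflow as tf

example_opt = configure_optimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1),
    use_float16=True,
    loss_scale='dynamic')
# example_opt is now wrapped in a LossScaleOptimizer that scales the loss to
# avoid float16 gradient underflow.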
def set_mixed_precision_policy(dtype, loss_scale=None):
"""Sets mix precision policy."""
if dtype == tf.float16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_float16', loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.float32:
tf.keras.mixed_precision.experimental.set_policy('float32')
else:
raise ValueError("Unexpected dtype: %s" % dtype)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
flags.DEFINE_string(
'sp_model_file', None,
'The path to the sentence piece model. Used by sentence piece tokenizer '
'employed by ALBERT.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
"""Runs bert squad training."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=FLAGS.sp_model_file)
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
bert_config, squad_lib_sp)
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
ValueError: If the export path is not specified, i.e. an empty string or None.
"""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):
train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
if FLAGS.mode in ('predict', 'train_and_predict'):
predict_squad(strategy, input_meta_data)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
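# A hypothetical invocation (the module path and file paths are placeholders;
# the flag names are the ones defined above and in run_squad_helper):
# python -m official.nlp.albert.run_squad \
#   --mode=train_and_predict \
#   --bert_config_file=/path/to/albert_config.json \
#   --sp_model_file=/path/to/30k-clean.model \
#   --input_meta_data_path=/path/to/squad_meta_data \
#   --train_data_path=/path/to/train.tf_record \
#   --predict_file=/path/to/dev-v1.1.json \
#   --model_dir=/tmp/albert_squad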
......@@ -69,6 +69,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
sentence_labels):
"""Implements call() for the layer."""
lm_label_weights = tf.cast(lm_label_weights, tf.float32)
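# Under a mixed-precision policy the model outputs may be float16; cast them
# to float32 so the losses below are computed in full precision.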
lm_output = tf.cast(lm_output, tf.float32)
sentence_output = tf.cast(sentence_output, tf.float32)
mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
......
......@@ -66,10 +66,6 @@ def define_common_bert_flags():
flags.DEFINE_string(
'hub_module_url', None, 'TF-Hub path/url to Bert module. '
'If specified, init_checkpoint flag should not be used.')
flags.DEFINE_enum(
'model_type', 'bert', ['bert', 'albert'],
'Specifies the type of the model. '
'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
......@@ -92,9 +88,17 @@ def define_common_bert_flags():
)
def dtype():
return flags_core.get_tf_dtype(flags.FLAGS)
def use_float16():
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
def use_graph_rewrite():
return flags.FLAGS.fp16_implementation == 'graph_rewrite'
def get_loss_scale():
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
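# A minimal sketch of how these flag helpers combine (this mirrors the
# classifier and pretraining changes elsewhere in this commit; `optimizer` is
# a placeholder):
# from official.modeling import performance
# optimizer = performance.configure_optimizer(
#     optimizer,
#     use_float16=use_float16(),
#     use_graph_rewrite=use_graph_rewrite(),
#     loss_scale=get_loss_scale())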
......@@ -27,6 +27,7 @@ from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
......@@ -126,16 +127,12 @@ def run_bert_classifier(strategy,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable))
classifier_model.optimizer = optimization.create_optimizer(
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
classifier_model.optimizer)
classifier_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
return classifier_model, core_model
# During distributed training, loss used for gradient computation is
......@@ -302,6 +299,7 @@ def run_bert(strategy,
raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
performance.set_mixed_precision_policy(common_flags.dtype())
epochs = FLAGS.num_train_epochs
train_data_size = input_meta_data['train_data_size']
......
......@@ -23,6 +23,7 @@ from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
......@@ -102,16 +103,12 @@ def run_customized_training(strategy,
"""Gets a pretraining model."""
pretrain_model, core_model = bert_models.pretrain_model(
bert_config, max_seq_length, max_predictions_per_seq)
pretrain_model.optimizer = optimization.create_optimizer(
optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
pretrain_model.optimizer)
pretrain_model.optimizer = performance.configure_optimizer(
optimizer,
use_float16=common_flags.use_float16(),
use_graph_rewrite=common_flags.use_graph_rewrite())
return pretrain_model, core_model
trained_model = model_training_utils.run_customized_training_loop(
......@@ -141,6 +138,8 @@ def run_bert_pretrain(strategy):
logging.info('Training using customized training loop TF 2.0 with distributed '
'strategy.')
performance.set_mixed_precision_policy(common_flags.dtype())
return run_customized_training(
strategy,
bert_config,
......
......@@ -19,361 +19,44 @@ from __future__ import division
from __future__ import print_function
import json
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.nlp import optimization
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_enum(
'mode', 'train_and_predict',
['train_and_predict', 'train', 'predict', 'export_only'],
'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'`train`: only trains the model. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.')
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be printed. '
'A number of warnings are expected for a normal SQuAD evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.')
flags.DEFINE_string(
'sp_model_file', None,
'The path to the sentence piece model. Used by sentence piece tokenizer '
'employed by ALBERT.')
common_flags.define_common_bert_flags()
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
MODEL_CLASSES = {
'bert': (bert_configs.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
'albert': (albert_configs.AlbertConfig, squad_lib_sp,
tokenization.FullSentencePieceTokenizer),
}
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_factor
return total_loss
def get_loss_fn(loss_factor=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=loss_factor)
return _loss_fn
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield squad_lib.RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
' strategy.')
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
use_float16 = common_flags.use_float16()
if use_float16:
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
squad_model.optimizer)
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(
loss_factor=1.0 /
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
if tokenizer_cls == tokenization.FullTokenizer:
tokenizer = tokenizer_cls(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
else:
assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
bert_config, squad_lib_wp)
def export_squad(model_export_path, input_meta_data):
......@@ -386,16 +69,8 @@ def export_squad(model_export_path, input_meta_data):
Raises:
ValueError: If the export path is not specified, i.e. an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
FLAGS.bert_config_file)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.data import squad_lib_sp
from official.utils.misc import keras_utils
def define_common_squad_flags():
"""Defines common flags used by SQuAD tasks."""
flags.DEFINE_enum(
'mode', 'train_and_predict',
['train_and_predict', 'train', 'predict', 'export_only'],
'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'`train`: only trains the model. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path: the SQuAD json file to run prediction on.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be '
'printed. A number of warnings are expected for a normal SQuAD '
'evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one '
'another.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_factor
return total_loss
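# A tiny worked example (tensor values are illustrative): the loss is the mean
# of the start- and end-position sparse categorical crossentropies.
import tensorflow as tf

example_loss = squad_loss_fn(
    start_positions=tf.constant([0]),
    end_positions=tf.constant([2]),
    start_logits=tf.constant([[2.0, 0.5, 0.1]]),
    end_logits=tf.constant([[0.1, 0.2, 3.0]]))
# example_loss is a scalar: (CE(start) + CE(end)) / 2, scaled by loss_factor.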
def get_loss_fn(loss_factor=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=loss_factor)
return _loss_fn
RawResult = collections.namedtuple('RawResult',
['unique_id', 'start_logits', 'end_logits'])
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
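# A minimal sketch of how the closure is consumed under a distribution
# strategy (this mirrors predict_squad_customized below; the file pattern is a
# placeholder):
# dataset_fn = get_dataset_fn('/path/to/eval.tf_record', max_seq_length=384,
#                             global_batch_size=8, is_training=False)
# dist_dataset = strategy.experimental_distribute_datasets_from_function(
#     dataset_fn)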
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
input_meta_data,
bert_config,
custom_callbacks=None,
run_eagerly=False):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
' strategy.')
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
use_float16 = common_flags.use_float16()
if use_float16:
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
squad_model.optimizer)
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(
loss_factor=1.0 /
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks)
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging)
def export_squad(model_export_path, input_meta_data, bert_config):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
bert_config: Bert configuration file to define core bert layers.
Raises:
ValueError: If the export path is not specified, i.e. an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
......@@ -490,10 +490,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
def write_predictions(all_examples,
all_features,
all_results,
......
......@@ -562,10 +562,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
def write_predictions(all_examples,
all_features,
all_results,
......
......@@ -17,7 +17,6 @@
See README for description of setting the training schedule and evaluating the
BLEU score.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -30,19 +29,19 @@ from absl import flags
from absl import logging
import tensorflow as tf
# pylint: disable=g-bad-import-order
from official.modeling import performance
from official.nlp.transformer import compute_bleu
from official.nlp.transformer.utils import tokenizer
from official.nlp.transformer import data_pipeline
from official.nlp.transformer import metrics
from official.nlp.transformer import misc
from official.nlp.transformer import optimizer
from official.nlp.transformer import transformer
from official.nlp.transformer import translate
from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import keras_utils
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
INF = int(1e9)
BLEU_DIR = "bleu"
......@@ -180,21 +179,9 @@ class TransformerTask(object):
else:
logging.info("Not using any distribution strategy.")
if params["dtype"] == tf.float16:
# TODO(reedwm): It's pretty ugly to set the global policy in a constructor
# like this. What if multiple instances of TransformerTask are created?
# We should have a better way in the tf.keras.mixed_precision API of doing
# this.
loss_scale = flags_core.get_loss_scale(
flags_obj, default_for_fp16="dynamic")
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
"mixed_float16", loss_scale=loss_scale)
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
elif params["dtype"] == tf.bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
"mixed_bfloat16")
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
performance.set_mixed_precision_policy(
params["dtype"],
flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))
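      # Added note (not part of the original diff): set_mixed_precision_policy
      # is assumed to fold in the inline logic removed above, i.e. building a
      # 'mixed_float16' Policy with the given loss scale (or 'mixed_bfloat16'
      # with no loss scale) and installing it via set_policy.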
@property
def use_tpu(self):
......@@ -434,8 +421,6 @@ class TransformerTask(object):
def _create_optimizer(self):
"""Creates optimizer."""
params = self.params
# TODO(b/139414679): Explore the difference between using
# LearningRateSchedule and callback for GPU runs, and try to merge them.
lr_schedule = optimizer.LearningRateSchedule(
params["learning_rate"], params["hidden_size"],
params["learning_rate_warmup_steps"])
......@@ -445,18 +430,12 @@ class TransformerTask(object):
params["optimizer_adam_beta2"],
epsilon=params["optimizer_adam_epsilon"])
if params["dtype"] == tf.float16:
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
opt,
loss_scale=flags_core.get_loss_scale(
self.flags_obj, default_for_fp16="dynamic"))
if self.flags_obj.fp16_implementation == "graph_rewrite":
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
opt = performance.configure_optimizer(
opt,
use_float16=params["dtype"] == tf.float16,
use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
loss_scale=flags_core.get_loss_scale(
self.flags_obj, default_for_fp16="dynamic"))
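    # Added note (not part of the original diff): configure_optimizer is assumed
    # to wrap `opt` in a LossScaleOptimizer when use_float16=True and to apply
    # enable_mixed_precision_graph_rewrite when use_graph_rewrite=True, matching
    # the inline code removed above.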
return opt
......
......@@ -21,7 +21,6 @@ from __future__ import print_function
import math
import unittest
import mock
import numpy as np
import tensorflow as tf
......@@ -192,34 +191,34 @@ class NcfTest(tf.test.TestCase):
_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat_ctl(self):
flags = (self._BASE_END_TO_END_FLAGS +
......@@ -229,7 +228,7 @@ class NcfTest(tf.test.TestCase):
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
if context.num_gpus() < 1:
......@@ -242,7 +241,7 @@ class NcfTest(tf.test.TestCase):
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
'--dtype', 'fp16'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
if context.num_gpus() < 1:
......@@ -256,7 +255,7 @@ class NcfTest(tf.test.TestCase):
'--dtype', 'fp16',
'--keras_use_ctl'])
@mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_2_gpu_fp16(self):
if context.num_gpus() < 2:
......
......@@ -43,14 +43,15 @@ def get_tf_dtype(flags_obj):
def get_loss_scale(flags_obj, default_for_fp16):
dtype = get_tf_dtype(flags_obj)
if flags_obj.loss_scale == "dynamic":
return flags_obj.loss_scale
elif flags_obj.loss_scale is not None:
return float(flags_obj.loss_scale)
elif flags_obj.dtype == "fp32":
elif dtype == tf.float32 or dtype == tf.bfloat16:
return 1  # No loss scaling is needed for fp32 or bfloat16.
else:
assert flags_obj.dtype == "fp16"
assert dtype == tf.float16
return default_for_fp16
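# Illustrative behaviour of get_loss_scale after the refactor (a sketch, assuming
# --loss_scale and --dtype are parsed as shown in the diff above):
#
#   --loss_scale=dynamic                  -> 'dynamic' (passed through)
#   --loss_scale=256                      -> 256.0
#   no --loss_scale, dtype fp32 or bf16   -> 1 (no scaling needed)
#   no --loss_scale, dtype fp16           -> default_for_fp16 (e.g. 'dynamic' or 128)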
......
......@@ -20,7 +20,6 @@ from __future__ import print_function
import os
from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
......@@ -36,78 +35,6 @@ LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
]
def learning_rate_schedule(current_epoch,
current_batch,
steps_per_epoch,
batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
Scales the learning rate at the epoch boundaries listed in LR_SCHEDULE by the
corresponding multiplier.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
steps_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch size.
Returns:
Adjusted learning rate.
"""
initial_lr = BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / steps_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
# Learning rate increases linearly per step.
return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
for mult, start_epoch in LR_SCHEDULE:
if epoch >= start_epoch:
learning_rate = initial_lr * mult
else:
break
return learning_rate
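# Worked example (illustrative, not part of the original diff), assuming
# BASE_LEARNING_RATE = 0.1 and LR_SCHEDULE = [(1.0, 5), (0.1, 30), ...]:
#   batch_size=1024           -> initial_lr = 0.1 * 1024 / 256 = 0.4
#   epoch 2, halfway through  -> still warming up: 0.4 * 1.0 * 2.5 / 5 = 0.2
#   epoch 40                  -> past warmup, (0.1, 30) applies: 0.4 * 0.1 = 0.04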
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Callback to update learning rate on every batch (not epoch boundaries).
N.B. Only supports Keras optimizers, not TF optimizers.
Attributes:
schedule: a function that takes an epoch index and a batch index as input
(both integer, indexed from 0) and returns a new learning rate as
output (float).
"""
def __init__(self, schedule, batch_size, steps_per_epoch):
super(LearningRateBatchScheduler, self).__init__()
self.schedule = schedule
self.steps_per_epoch = steps_per_epoch
self.batch_size = batch_size
self.epochs = -1
self.prev_lr = -1
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'learning_rate'):
raise ValueError('Optimizer must have a "learning_rate" attribute.')
self.epochs += 1
def on_batch_begin(self, batch, logs=None):
"""Executes before step begins."""
lr = self.schedule(self.epochs,
batch,
self.steps_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be a float.')
if lr != self.prev_lr:
self.model.optimizer.learning_rate = lr # lr should be a float here
self.prev_lr = lr
tf.compat.v1.logging.debug(
'Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.', self.epochs, batch, lr)
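# Illustrative usage (a sketch, not part of the original diff): attaching the
# per-batch scheduler above to a Keras fit() loop, mirroring the construction
# removed from get_callbacks further down in this diff.
#
#   lr_callback = LearningRateBatchScheduler(
#       schedule=learning_rate_schedule,
#       batch_size=flags_obj.batch_size,
#       steps_per_epoch=steps_per_epoch)
#   model.fit(train_ds, epochs=train_epochs, callbacks=[lr_callback])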
class PiecewiseConstantDecayWithWarmup(
tf.keras.optimizers.schedules.LearningRateSchedule):
"""Piecewise constant decay with warmup schedule."""
......@@ -180,10 +107,8 @@ def get_optimizer(learning_rate=0.1):
return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)
# TODO(hongkuny,haoyuzhang): make cifar model use_tensor_lr to clean up code.
def get_callbacks(
steps_per_epoch,
learning_rate_schedule_fn=None,
pruning_method=None,
enable_checkpoint_and_export=False,
model_dir=None):
......@@ -194,13 +119,6 @@ def get_callbacks(
logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
callbacks = [time_callback]
if not FLAGS.use_tensor_lr and learning_rate_schedule_fn:
lr_callback = LearningRateBatchScheduler(
learning_rate_schedule_fn,
batch_size=FLAGS.batch_size,
steps_per_epoch=steps_per_epoch)
callbacks.append(lr_callback)
if FLAGS.enable_tensorboard:
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir)
......@@ -317,7 +235,7 @@ def define_keras_flags(
help='Whether to use a trivial Keras model.')
flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
help='Report metrics during training and evaluation.')
flags.DEFINE_boolean(name='use_tensor_lr', default=False,
flags.DEFINE_boolean(name='use_tensor_lr', default=True,
help='Use learning rate tensor instead of a callback.')
flags.DEFINE_boolean(
name='enable_tensorboard', default=False,
......
......@@ -23,6 +23,7 @@ from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import performance
from official.staging.training import controller
from official.utils.flags import core as flags_core
from official.utils.logs import logger
......@@ -110,16 +111,7 @@ def run(flags_obj):
keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == tf.float16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_float16')
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
# This only affects GPU.
common.set_cudnn_batchnorm_mode()
......
......@@ -28,6 +28,7 @@ import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.benchmark.models import trivial_model
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
......@@ -65,17 +66,9 @@ def run(flags_obj):
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == tf.float16:
loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_float16', loss_scale=loss_scale)
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
if not keras_utils.is_v2_0():
raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
elif dtype == tf.bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
performance.set_mixed_precision_policy(
flags_core.get_tf_dtype(flags_obj),
flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
data_format = flags_obj.data_format
if data_format is None:
......@@ -155,23 +148,19 @@ def run(flags_obj):
dtype=dtype,
drop_remainder=drop_remainder)
lr_schedule = 0.1
if flags_obj.use_tensor_lr:
lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=common.LR_SCHEDULE[0][1],
boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
multipliers=list(p[0] for p in common.LR_SCHEDULE),
compute_lr_on_cpu=True)
lr_schedule = common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size,
epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=common.LR_SCHEDULE[0][1],
boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
multipliers=list(p[0] for p in common.LR_SCHEDULE),
compute_lr_on_cpu=True)
steps_per_epoch = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
learning_rate_schedule_fn = None
with strategy_scope:
if flags_obj.optimizer == 'resnet50_default':
optimizer = common.get_optimizer(lr_schedule)
learning_rate_schedule_fn = common.learning_rate_schedule
elif flags_obj.optimizer == 'mobilenet_default':
initial_learning_rate = \
flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
......@@ -248,7 +237,6 @@ def run(flags_obj):
callbacks = common.get_callbacks(
steps_per_epoch=steps_per_epoch,
learning_rate_schedule_fn=learning_rate_schedule_fn,
pruning_method=flags_obj.pruning_method,
enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
model_dir=flags_obj.model_dir)
......