"...en/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "86490d055be0db7ce0a33b0bd53e0d2907d2f92e"
Unverified commit 9b049266, authored by Ayushman Kumar and committed by GitHub

Merge pull request #3 from tensorflow/master

Updated
parents 63af6ba5 c5ad244e
@@ -82,15 +82,27 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
    with tf.io.gfile.GFile(predictions_file, 'r') as reader:
      return json.load(reader)

-  def _get_distribution_strategy(self, use_ds=True):
-    """Gets the distribution strategy."""
-    if self.tpu:
+  def _get_distribution_strategy(self, ds_type='mirrored'):
+    """Gets the distribution strategy.
+
+    Args:
+      ds_type: String, the distribution strategy type to be used. Can be
+        'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
+
+    Returns:
+      A `tf.distribute.DistributionStrategy` object.
+    """
+    if self.tpu or ds_type == 'tpu':
      return distribution_utils.get_distribution_strategy(
          distribution_strategy='tpu', tpu_address=self.tpu)
-    else:
-      return distribution_utils.get_distribution_strategy(
-          distribution_strategy='mirrored' if use_ds else 'off',
-          num_gpus=self.num_gpus)
+    elif ds_type == 'multi_worker_mirrored':
+      # Configures cluster spec for multi-worker distribution strategy.
+      _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
+                                               FLAGS.task_index)
+    return distribution_utils.get_distribution_strategy(
+        distribution_strategy=ds_type,
+        num_gpus=self.num_gpus,
+        all_reduce_alg=FLAGS.all_reduce_alg)

  def _init_gpu_and_data_threads(self):
    """Set env variables before any TF calls."""
@@ -102,12 +114,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
        datasets_num_private_threads=FLAGS.datasets_num_private_threads)

  @flagsaver.flagsaver
-  def _train_squad(self, use_ds=True, run_eagerly=False):
-    """Runs BERT SQuAD training."""
+  def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
+    """Runs BERT SQuAD training. Uses mirrored strategy by default."""
    assert tf.version.VERSION.startswith('2.')
    self._init_gpu_and_data_threads()
    input_meta_data = self._read_input_meta_data_from_file()
-    strategy = self._get_distribution_strategy(use_ds)
+    strategy = self._get_distribution_strategy(ds_type)

    run_squad.train_squad(
        strategy=strategy,
@@ -116,12 +128,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
        custom_callbacks=[self.timer_callback])

  @flagsaver.flagsaver
-  def _evaluate_squad(self, use_ds=True):
-    """Runs BERT SQuAD evaluation."""
+  def _evaluate_squad(self, ds_type='mirrored'):
+    """Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
    assert tf.version.VERSION.startswith('2.')
    self._init_gpu_and_data_threads()
    input_meta_data = self._read_input_meta_data_from_file()
-    strategy = self._get_distribution_strategy(use_ds)
+    strategy = self._get_distribution_strategy(ds_type)

    run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)
@@ -157,15 +169,15 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
-                                use_ds=True,
-                                run_eagerly=False):
+                                run_eagerly=False,
+                                ds_type='mirrored'):
    """Runs the benchmark and reports various metrics."""
    if FLAGS.train_batch_size <= 4:
      FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
    else:
      FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
    start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
+    self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
    wall_time_sec = time.time() - start_time_sec

    summary = self._read_training_summary_from_file()
@@ -217,7 +229,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
    FLAGS.train_batch_size = 4

-    self._run_and_report_benchmark(use_ds=False)
+    self._run_and_report_benchmark(ds_type='off')

  def benchmark_1_gpu_eager_no_dist_strat(self):
    """Tests BERT SQuAD model performance with 1 GPU with eager execution."""
@@ -228,7 +240,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
        'benchmark_1_gpu_eager_no_dist_strat_squad')
    FLAGS.train_batch_size = 4

-    self._run_and_report_benchmark(use_ds=False, run_eagerly=True)
+    self._run_and_report_benchmark(ds_type='off', run_eagerly=True)

  def benchmark_2_gpu(self):
    """Tests BERT SQuAD model performance with 2 GPUs."""
@@ -420,12 +432,12 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
-                                use_ds=True,
-                                run_eagerly=False):
+                                run_eagerly=False,
+                                ds_type='mirrored'):
    """Runs the benchmark and reports various metrics."""
    start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
-    self._evaluate_squad()
+    self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
+    self._evaluate_squad(ds_type=ds_type)
    wall_time_sec = time.time() - start_time_sec

    summary = self._read_training_summary_from_file()
@@ -445,7 +457,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
    FLAGS.train_batch_size = 4

-    self._run_and_report_benchmark(use_ds=False, run_eagerly=True)
+    self._run_and_report_benchmark(ds_type='off', run_eagerly=True)

  def benchmark_8_gpu(self):
    """Tests BERT SQuAD model accuracy with 8 GPUs."""
@@ -518,8 +530,9 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
                                run_eagerly=False):
    """Runs the benchmark and reports various metrics."""
    start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
-    self._evaluate_squad()
+    self._train_squad(run_eagerly=run_eagerly,
+                      ds_type='multi_worker_mirrored')
+    self._evaluate_squad(ds_type='multi_worker_mirrored')
    wall_time_sec = time.time() - start_time_sec

    summary = self._read_training_summary_from_file()
@@ -595,7 +608,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
    else:
      FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
    start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
+    self._train_squad(run_eagerly=run_eagerly,
+                      ds_type='multi_worker_mirrored')
    wall_time_sec = time.time() - start_time_sec

    summary = self._read_training_summary_from_file()
...
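With this change every benchmark funnels through the same distribution_utils call; for the multi-worker case the sequence reduces to roughly the following sketch (the worker_hosts, task_index and num_gpus values are placeholders, in the benchmarks they come from FLAGS):

    from official.utils.misc import distribution_utils

    # Placeholder cluster description; the benchmarks read these from FLAGS.
    worker_hosts = 'host1:port,host2:port'
    task_index = 0

    # Build the cluster spec first, then request the matching strategy.
    _ = distribution_utils.configure_cluster(worker_hosts, task_index)
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='multi_worker_mirrored',
        num_gpus=8,
        all_reduce_alg=None)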
@@ -53,6 +53,10 @@ class Resnet56KerasAccuracy(keras_benchmark.KerasBenchmark):
    super(Resnet56KerasAccuracy, self).__init__(
        output_dir=output_dir, flag_methods=flag_methods)

+  def _setup(self):
+    super(Resnet56KerasAccuracy, self)._setup()
+    FLAGS.use_tensor_lr = False
+
  def benchmark_graph_1_gpu(self):
    """Test keras based model with Keras fit and distribution strategies."""
    self._setup()
@@ -439,6 +443,7 @@ class Resnet56KerasBenchmarkSynth(Resnet56KerasBenchmarkBase):
    default_flags['use_synthetic_data'] = True
    default_flags['train_steps'] = 110
    default_flags['log_steps'] = 10
+    default_flags['use_tensor_lr'] = False

    super(Resnet56KerasBenchmarkSynth, self).__init__(
        output_dir=output_dir, default_flags=default_flags)
@@ -453,6 +458,7 @@ class Resnet56KerasBenchmarkReal(Resnet56KerasBenchmarkBase):
    default_flags['data_dir'] = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
    default_flags['train_steps'] = 110
    default_flags['log_steps'] = 10
+    default_flags['use_tensor_lr'] = False

    super(Resnet56KerasBenchmarkReal, self).__init__(
        output_dir=output_dir, default_flags=default_flags)
...
@@ -71,7 +71,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
    FLAGS.dtype = 'fp32'
-    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  def benchmark_8_gpu(self):
@@ -87,7 +86,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.enable_eager = True
    # Add some thread tunings to improve performance.
    FLAGS.datasets_num_private_threads = 14
-    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  def benchmark_8_gpu_amp(self):
@@ -104,7 +102,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.fp16_implementation = 'graph_rewrite'
    # Add some thread tunings to improve performance.
    FLAGS.datasets_num_private_threads = 14
-    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  def benchmark_8_gpu_fp16(self):
@@ -120,7 +117,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.enable_eager = True
    # Thread tuning to improve performance.
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
-    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  def benchmark_xla_8_gpu_fp16(self):
@@ -137,7 +133,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.enable_xla = True
    # Thread tuning to improve performance.
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
-    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  def benchmark_8_gpu_mlperf_like(self):
@@ -179,7 +174,6 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.loss_scale = 'dynamic'
    # Thread tuning to improve performance.
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
-    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark(top_1_min=0.736)

  @benchmark_wrappers.enable_runtime_flags
@@ -241,7 +235,6 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.enable_eager = True
    # Add some thread tunings to improve performance.
    FLAGS.datasets_num_private_threads = 14
-    FLAGS.use_tensor_lr = True
    self._run_and_report_benchmark()

  @benchmark_wrappers.enable_runtime_flags
@@ -472,7 +465,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
    FLAGS.dtype = 'fp16'
    FLAGS.batch_size = 256
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -550,7 +542,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
        'benchmark_graph_xla_1_gpu_fp16_tweaked')
    FLAGS.dtype = 'fp16'
    FLAGS.batch_size = 256
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -587,7 +578,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
    FLAGS.batch_size = 128 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
    FLAGS.datasets_num_private_threads = 14
    self._run_and_report_benchmark()
@@ -627,7 +617,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_tweaked')
    FLAGS.batch_size = 128 * 8
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.datasets_num_private_threads = 24
    self._run_and_report_benchmark()
@@ -654,7 +643,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -670,7 +658,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
        'benchmark_8_gpu_fp16_dynamic_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
    FLAGS.loss_scale = 'dynamic'
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -698,7 +685,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.datasets_num_private_threads = 48
    self._run_and_report_benchmark()
@@ -718,7 +704,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_xla_8_gpu_fp16_tweaked_delay_measure')
    FLAGS.batch_size = 256 * 8
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.train_steps = 310
    self._run_and_report_benchmark()
@@ -736,7 +721,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
        'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
    FLAGS.loss_scale = 'dynamic'
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.datasets_num_private_threads = 48
    self._run_and_report_benchmark()
@@ -799,7 +783,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu_fp16_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -815,7 +798,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_graph_xla_8_gpu_fp16_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -834,7 +816,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_graph_xla_8_gpu_fp16_tweaked_delay_measure')
    FLAGS.batch_size = 256 * 8
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.train_steps = 310
    self._run_and_report_benchmark()
@@ -851,7 +832,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
        'benchmark_graph_8_gpu_fp16_dynamic_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
    FLAGS.loss_scale = 'dynamic'
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -867,7 +847,6 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_graph_xla_8_gpu_fp16_dynamic_tweaked')
    FLAGS.batch_size = 256 * 8  # 8 GPUs
-    FLAGS.use_tensor_lr = True
    FLAGS.loss_scale = 'dynamic'
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    self._run_and_report_benchmark()
@@ -963,7 +942,6 @@ class TrivialKerasBenchmarkReal(keras_benchmark.KerasBenchmark):
    def_flags['use_trivial_model'] = True
    def_flags['skip_eval'] = True
    def_flags['report_accuracy_metrics'] = False
-    def_flags['use_tensor_lr'] = True
    def_flags['dtype'] = 'fp16'
    def_flags['data_dir'] = os.path.join(root_data_dir, 'imagenet')
    def_flags['train_steps'] = 600
@@ -1097,7 +1075,6 @@ class Resnet50MultiWorkerKerasAccuracy(keras_benchmark.KerasBenchmark):
    FLAGS.enable_eager = eager
    FLAGS.enable_xla = False
    FLAGS.distribution_strategy = 'multi_worker_mirrored'
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.datasets_num_private_threads = 32
    FLAGS.model_dir = self._get_model_dir(
@@ -1161,7 +1138,6 @@ class Resnet50MultiWorkerKerasBenchmark(Resnet50KerasBenchmarkBase):
    FLAGS.enable_eager = eager
    FLAGS.enable_xla = False
    FLAGS.distribution_strategy = 'multi_worker_mirrored'
-    FLAGS.use_tensor_lr = True
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.datasets_num_private_threads = 32
    FLAGS.model_dir = self._get_model_dir(
...
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from absl import app as absl_app
+import numpy as np
from absl import flags
import tensorflow as tf

from official.benchmark.models import resnet_cifar_model
@@ -64,6 +64,46 @@ def learning_rate_schedule(current_epoch,
  return learning_rate


+class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
+  """Callback to update learning rate on every batch (not epoch boundaries).
+
+  N.B. Only supports Keras optimizers, not TF optimizers.
+
+  Attributes:
+    schedule: a function that takes an epoch index and a batch index as input
+      (both integer, indexed from 0) and returns a new learning rate as
+      output (float).
+  """
+
+  def __init__(self, schedule, batch_size, steps_per_epoch):
+    super(LearningRateBatchScheduler, self).__init__()
+    self.schedule = schedule
+    self.steps_per_epoch = steps_per_epoch
+    self.batch_size = batch_size
+    self.epochs = -1
+    self.prev_lr = -1
+
+  def on_epoch_begin(self, epoch, logs=None):
+    if not hasattr(self.model.optimizer, 'learning_rate'):
+      raise ValueError('Optimizer must have a "learning_rate" attribute.')
+    self.epochs += 1
+
+  def on_batch_begin(self, batch, logs=None):
+    """Executes before step begins."""
+    lr = self.schedule(self.epochs,
+                       batch,
+                       self.steps_per_epoch,
+                       self.batch_size)
+    if not isinstance(lr, (float, np.float32, np.float64)):
+      raise ValueError('The output of the "schedule" function should be float.')
+    if lr != self.prev_lr:
+      self.model.optimizer.learning_rate = lr  # lr should be a float here
+      self.prev_lr = lr
+      tf.compat.v1.logging.debug(
+          'Epoch %05d Batch %05d: LearningRateBatchScheduler '
+          'change learning rate to %s.', self.epochs, batch, lr)
+
+
def run(flags_obj):
  """Run ResNet Cifar-10 training and eval loop using native Keras APIs.
@@ -151,8 +191,18 @@ def run(flags_obj):
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=cifar_preprocessing.parse_record)

+  steps_per_epoch = (
+      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+  lr_schedule = 0.1
+  if flags_obj.use_tensor_lr:
+    initial_learning_rate = common.BASE_LEARNING_RATE * flags_obj.batch_size / 128
+    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+        boundaries=list(p[1] * steps_per_epoch for p in LR_SCHEDULE),
+        values=[initial_learning_rate] +
+        list(p[0] * initial_learning_rate for p in LR_SCHEDULE))
+
  with strategy_scope:
-    optimizer = common.get_optimizer()
+    optimizer = common.get_optimizer(lr_schedule)
    model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)

    # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
@@ -173,11 +223,16 @@ def run(flags_obj):
                 if flags_obj.report_accuracy_metrics else None),
        run_eagerly=flags_obj.run_eagerly)

-  steps_per_epoch = (
-      cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
  train_epochs = flags_obj.train_epochs

-  callbacks = common.get_callbacks(steps_per_epoch, learning_rate_schedule)
+  callbacks = common.get_callbacks(steps_per_epoch)
+  if not flags_obj.use_tensor_lr:
+    lr_callback = LearningRateBatchScheduler(
+        schedule=learning_rate_schedule,
+        batch_size=flags_obj.batch_size,
+        steps_per_epoch=steps_per_epoch)
+    callbacks.append(lr_callback)

  # if multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
...
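The two learning-rate paths above are mutually exclusive: with use_tensor_lr the schedule is baked into the optimizer as a PiecewiseConstantDecay, otherwise the LearningRateBatchScheduler callback rewrites optimizer.learning_rate before every batch. A toy sketch of the callback path, using the class defined above (the schedule function, model and data below are placeholders, not code from this repository):

    import numpy as np
    import tensorflow as tf

    def toy_schedule(epoch, batch, steps_per_epoch, batch_size):
      # Placeholder schedule: drop the learning rate 10x after the first epoch.
      del batch, steps_per_epoch, batch_size  # Unused in this toy example.
      return 0.1 if epoch < 1 else 0.01

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer=tf.keras.optimizers.SGD(0.1), loss='mse')

    lr_callback = LearningRateBatchScheduler(
        schedule=toy_schedule, batch_size=32, steps_per_epoch=2)

    x = np.random.rand(64, 4).astype('float32')
    y = np.random.rand(64, 1).astype('float32')
    model.fit(x, y, batch_size=32, epochs=2, callbacks=[lr_callback])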
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to training performance."""
import tensorflow as tf
def configure_optimizer(optimizer,
use_float16=False,
use_graph_rewrite=False,
loss_scale="dynamic"):
"""Configures optimizer object with performance options."""
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=loss_scale))
if use_graph_rewrite:
    # Note: the model dtype must be 'float32', which will ensure
    # tf.keras.mixed_precision and
    # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
    # up.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer)
return optimizer
def set_mixed_precision_policy(dtype, loss_scale=None):
"""Sets mix precision policy."""
if dtype == tf.float16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_float16', loss_scale=loss_scale)
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.bfloat16:
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.keras.mixed_precision.experimental.set_policy(policy)
elif dtype == tf.float32:
tf.keras.mixed_precision.experimental.set_policy('float32')
else:
raise ValueError("Unexpected dtype: %s" % dtype)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
flags.DEFINE_string(
'sp_model_file', None,
'The path to the sentence piece model. Used by sentence piece tokenizer '
'employed by ALBERT.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
"""Runs bert squad training."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=FLAGS.sp_model_file)
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
bert_config, squad_lib_sp)
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
Export path is not specified, got an empty string or None.
"""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):
train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
if FLAGS.mode in ('predict', 'train_and_predict'):
predict_squad(strategy, input_meta_data)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
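For orientation, main() above wires these pieces together; a stripped-down sketch of driving the same ALBERT SQuAD flow programmatically (the metadata path is a placeholder, and the required flags such as --bert_config_file, --sp_model_file, --train_data_path and --model_dir are assumed to have been parsed already):

    import json

    import tensorflow as tf

    from official.utils.misc import distribution_utils

    # Placeholder path; in practice this comes from --input_meta_data_path.
    with tf.io.gfile.GFile('/path/to/squad_meta_data.json', 'rb') as reader:
      input_meta_data = json.loads(reader.read().decode('utf-8'))

    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='mirrored', num_gpus=1)

    train_squad(strategy, input_meta_data, run_eagerly=False)
    predict_squad(strategy, input_meta_data)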
@@ -69,6 +69,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
              sentence_labels):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
+    lm_output = tf.cast(lm_output, tf.float32)
+    sentence_output = tf.cast(sentence_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
...
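The added casts are the usual guard for mixed_float16 runs: the heads emit float16, while the loss is accumulated in float32. The same pattern in isolation (shapes and tensors below are illustrative only):

    import tensorflow as tf

    # Illustrative float16 logits, as produced under a mixed_float16 policy.
    logits = tf.random.normal([8, 30522], dtype=tf.float16)
    labels = tf.zeros([8], dtype=tf.int32)

    # Cast to float32 before the loss so the reduction runs in full precision.
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, tf.cast(logits, tf.float32), from_logits=True)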
@@ -66,10 +66,6 @@ def define_common_bert_flags():
  flags.DEFINE_string(
      'hub_module_url', None, 'TF-Hub path/url to Bert module. '
      'If specified, init_checkpoint flag should not be used.')
-  flags.DEFINE_enum(
-      'model_type', 'bert', ['bert', 'albert'],
-      'Specifies the type of the model. '
-      'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
  flags.DEFINE_bool('hub_module_trainable', True,
                    'True to make keras layers in the hub module trainable.')
@@ -92,9 +88,17 @@ def define_common_bert_flags():
  )


+def dtype():
+  return flags_core.get_tf_dtype(flags.FLAGS)
+
+
def use_float16():
  return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16


+def use_graph_rewrite():
+  return flags.FLAGS.fp16_implementation == 'graph_rewrite'
+
+
def get_loss_scale():
  return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
@@ -27,6 +27,7 @@ from absl import logging
import tensorflow as tf

from official.modeling import model_training_utils
+from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
@@ -126,16 +127,12 @@ def run_bert_classifier(strategy,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable))
-    classifier_model.optimizer = optimization.create_optimizer(
+    optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          classifier_model.optimizer)
+    classifier_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
    return classifier_model, core_model

  # During distributed training, loss used for gradient computation is
@@ -302,6 +299,7 @@ def run_bert(strategy,
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_config_v2(FLAGS.enable_xla)
+  performance.set_mixed_precision_policy(common_flags.dtype())
  epochs = FLAGS.num_train_epochs
  train_data_size = input_meta_data['train_data_size']
...
@@ -23,6 +23,7 @@ from absl import logging
import tensorflow as tf

from official.modeling import model_training_utils
+from official.modeling import performance
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
@@ -102,16 +103,12 @@ def run_customized_training(strategy,
    """Gets a pretraining model."""
    pretrain_model, core_model = bert_models.pretrain_model(
        bert_config, max_seq_length, max_predictions_per_seq)
-    pretrain_model.optimizer = optimization.create_optimizer(
+    optimizer = optimization.create_optimizer(
        initial_lr, steps_per_epoch * epochs, warmup_steps)
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          pretrain_model.optimizer)
+    pretrain_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
    return pretrain_model, core_model

  trained_model = model_training_utils.run_customized_training_loop(
@@ -141,6 +138,8 @@ def run_bert_pretrain(strategy):
  logging.info('Training using customized training loop TF 2.0 with distrubuted'
               'strategy.')

+  performance.set_mixed_precision_policy(common_flags.dtype())
+
  return run_customized_training(
      strategy,
      bert_config,
...
@@ -19,361 +19,44 @@ from __future__ import division
from __future__ import print_function

import json
-import os

from absl import app
from absl import flags
-from absl import logging
import tensorflow as tf

-from official.modeling import model_training_utils
-from official.nlp import optimization
-from official.nlp.albert import configs as albert_configs
-from official.nlp.bert import bert_models
-from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
-from official.nlp.bert import input_pipeline
-from official.nlp.bert import model_saving_utils
+from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
-# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
-# sentence-piece tokenizer based squad_lib
-from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
-from official.utils.misc import keras_utils
flags.DEFINE_enum(
'mode', 'train_and_predict',
['train_and_predict', 'train', 'predict', 'export_only'],
'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'`train`: only trains the model. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.')
flags.DEFINE_string('vocab_file', None,
                    'The vocabulary file that the BERT model was trained on.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be printed. '
'A number of warnings are expected for a normal SQuAD evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.')
flags.DEFINE_string(
'sp_model_file', None,
'The path to the sentence piece model. Used by sentence piece tokenizer '
'employed by ALBERT.')
-common_flags.define_common_bert_flags()
+# More flags can be found in run_squad_helper.
+run_squad_helper.define_common_squad_flags()

FLAGS = flags.FLAGS
MODEL_CLASSES = {
'bert': (bert_configs.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
'albert': (albert_configs.AlbertConfig, squad_lib_sp,
tokenization.FullSentencePieceTokenizer),
}
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_factor
return total_loss
def get_loss_fn(loss_factor=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=loss_factor)
return _loss_fn
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield squad_lib.RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
                input_meta_data,
                custom_callbacks=None,
                run_eagerly=False):
  """Run bert squad training."""
-  if strategy:
-    logging.info('Training using customized training loop with distribution'
-                 ' strategy.')
+  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+  run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
+                               custom_callbacks, run_eagerly)
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
use_float16 = common_flags.use_float16()
if use_float16:
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
squad_model.optimizer)
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(
loss_factor=1.0 /
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks)
def predict_squad(strategy, input_meta_data):
  """Makes predictions for a squad dataset."""
-  config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
-  bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
-  if tokenizer_cls == tokenization.FullTokenizer:
-    tokenizer = tokenizer_cls(
-        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+  tokenizer = tokenization.FullTokenizer(
+      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
+                                 bert_config, squad_lib_wp)
else:
assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging)
def export_squad(model_export_path, input_meta_data): def export_squad(model_export_path, input_meta_data):
...@@ -386,16 +69,8 @@ def export_squad(model_export_path, input_meta_data): ...@@ -386,16 +69,8 @@ def export_squad(model_export_path, input_meta_data):
Raises: Raises:
Export path is not specified, got an empty string or None. Export path is not specified, got an empty string or None.
""" """
if not model_export_path: bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
raise ValueError('Export path is not specified: %s' % model_export_path) run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
FLAGS.bert_config_file)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
def main(_): def main(_):
......
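# Hypothetical usage sketch (not part of this change; paths and values are
# made up): after the refactor, the thin wrappers above delegate all SQuAD
# logic to run_squad_helper, and prediction is driven entirely by flags,
# roughly:
#
#   python run_squad.py \
#     --mode=predict \
#     --bert_config_file=/path/to/bert_config.json \
#     --vocab_file=/path/to/vocab.txt \
#     --predict_file=/path/to/dev-v1.1.json \
#     --input_meta_data_path=/path/to/squad_meta_data \
#     --model_dir=/path/to/trained_model
#
# Here run_squad.py stands for the BERT entry-point script being edited; the
# flag names mirror those referenced in the code above and in
# define_common_squad_flags() below.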
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.data import squad_lib_sp
from official.utils.misc import keras_utils
def define_common_squad_flags():
"""Defines common flags used by SQuAD tasks."""
flags.DEFINE_enum(
'mode', 'train_and_predict',
['train_and_predict', 'train', 'predict', 'export_only'],
'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'`train`: only trains the model. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path with eval examples.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be '
'printed. A number of warnings are expected for a normal SQuAD '
'evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one '
'another.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_factor
return total_loss
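# Illustrative example (not part of the library): a tiny batch showing how the
# loss above averages the start- and end-position crossentropies. The logits
# and labels are made-up values for a sequence length of 4.
_toy_start_logits = tf.constant([[2.0, 0.1, 0.1, 0.1], [0.1, 0.1, 3.0, 0.1]])
_toy_end_logits = tf.constant([[0.1, 2.5, 0.1, 0.1], [0.1, 0.1, 0.1, 3.5]])
_toy_start_positions = tf.constant([0, 2])
_toy_end_positions = tf.constant([1, 3])
# Both heads mostly point at the labelled positions, so the combined loss is a
# small positive scalar (about 0.2 for these numbers).
_toy_loss = squad_loss_fn(_toy_start_positions, _toy_end_positions,
                          _toy_start_logits, _toy_end_logits)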
def get_loss_fn(loss_factor=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=loss_factor)
return _loss_fn
RawResult = collections.namedtuple('RawResult',
['unique_id', 'start_logits', 'end_logits'])
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
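# Illustrative example (not part of the library): when the strategy passes an
# input context to the closure above, the global batch size is divided evenly
# across replicas. The numbers here are arbitrary.
_example_ctx = tf.distribute.InputContext(
    num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=4)
assert _example_ctx.get_per_replica_batch_size(32) == 8  # 32 global / 4 replicas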
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
input_meta_data,
bert_config,
custom_callbacks=None,
run_eagerly=False):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
' strategy.')
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
use_float16 = common_flags.use_float16()
if use_float16:
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
squad_model.optimizer)
return squad_model, core_model
# The original BERT model does not scale the loss by 1/num_replicas_in_sync,
# which may have been unintentional. To keep the published hyperparameters
# working, we do the same by default and leave each replica's loss unscaled
# unless FLAGS.scale_loss is set.
loss_fn = get_loss_fn(
loss_factor=1.0 /
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks)
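# Worked example (hypothetical numbers, not taken from any dataset): with
# input_meta_data['train_data_size'] = 88000, --train_batch_size=32 and
# --num_train_epochs=2, the schedule above works out to
#   steps_per_epoch = int(88000 / 32)            = 2750
#   total steps     = 2750 * 2                   = 5500
#   warmup_steps    = int(2 * 88000 * 0.1 / 32)  = 550
# i.e. the learning rate warms up over roughly the first 10% of all training
# steps before the decay schedule created in _get_squad_model() takes over.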
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging)
def export_squad(model_export_path, input_meta_data, bert_config):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
bert_config: a BertConfig object defining the core BERT layers.
Raises:
ValueError: if the export path is not specified (empty string or None).
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
...@@ -490,10 +490,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
  return cur_span_index == best_span_index


def write_predictions(all_examples,
                      all_features,
                      all_results,
...
...@@ -562,10 +562,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
  return cur_span_index == best_span_index


def write_predictions(all_examples,
                      all_features,
                      all_results,
...
...@@ -17,7 +17,6 @@
See README for description of setting the training schedule and evaluating the
BLEU score.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
...@@ -30,19 +29,19 @@ from absl import flags
from absl import logging
import tensorflow as tf

from official.modeling import performance
from official.nlp.transformer import compute_bleu
from official.nlp.transformer import data_pipeline
from official.nlp.transformer import metrics
from official.nlp.transformer import misc
from official.nlp.transformer import optimizer
from official.nlp.transformer import transformer
from official.nlp.transformer import translate
from official.nlp.transformer.utils import tokenizer
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils

INF = int(1e9)
BLEU_DIR = "bleu"
...@@ -180,21 +179,9 @@ class TransformerTask(object):
    else:
      logging.info("Not using any distribution strategy.")

    performance.set_mixed_precision_policy(
        params["dtype"],
        flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))

  @property
  def use_tpu(self):
...@@ -434,8 +421,6 @@ class TransformerTask(object):
  def _create_optimizer(self):
    """Creates optimizer."""
    params = self.params
    lr_schedule = optimizer.LearningRateSchedule(
        params["learning_rate"], params["hidden_size"],
        params["learning_rate_warmup_steps"])
...@@ -445,18 +430,12 @@ class TransformerTask(object):
        params["optimizer_adam_beta2"],
        epsilon=params["optimizer_adam_epsilon"])

    opt = performance.configure_optimizer(
        opt,
        use_float16=params["dtype"] == tf.float16,
        use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
        loss_scale=flags_core.get_loss_scale(
            self.flags_obj, default_for_fp16="dynamic"))

    return opt
...
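# Hedged sketch (my own illustration, not code from this change): judging from
# the calls above, performance.set_mixed_precision_policy and
# performance.configure_optimizer centralize the mixed-precision setup that
# each script previously inlined. Something along these lines, where the
# function names below are placeholders:
import tensorflow as tf


def _set_mixed_precision_policy_sketch(dtype, loss_scale=None):
  """Selects a global Keras mixed-precision policy for the given dtype."""
  if dtype == tf.float16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    tf.compat.v2.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')


def _configure_optimizer_sketch(opt, use_float16, use_graph_rewrite, loss_scale):
  """Wraps an optimizer for loss scaling and/or the graph-rewrite path."""
  if use_float16:
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, loss_scale=loss_scale)
  if use_graph_rewrite:
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  return opt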
...@@ -21,7 +21,6 @@ from __future__ import print_function
import math
import unittest

import numpy as np
import tensorflow as tf
...@@ -192,34 +191,34 @@ class NcfTest(tf.test.TestCase):

  _BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']

  @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_estimator(self):
    integration.run_synthetic(
        ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS)

  @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_estimator_mlperf(self):
    integration.run_synthetic(
        ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_keras_no_dist_strat(self):
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS +
        ['-distribution_strategy', 'off'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_dist_strat(self):
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_dist_strat_ctl(self):
    flags = (self._BASE_END_TO_END_FLAGS +
...@@ -229,7 +228,7 @@ class NcfTest(tf.test.TestCase):
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=flags)

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
    if context.num_gpus() < 1:
...@@ -242,7 +241,7 @@ class NcfTest(tf.test.TestCase):
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
                                                   '--dtype', 'fp16'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
    if context.num_gpus() < 1:
...@@ -256,7 +255,7 @@ class NcfTest(tf.test.TestCase):
                    '--dtype', 'fp16',
                    '--keras_use_ctl'])

  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_2_gpu_fp16(self):
    if context.num_gpus() < 2:
...
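# Quick illustration (not from this change): the decorators above now use the
# standard-library unittest.mock instead of the third-party mock package.
# Patching a module-level constant works the same way; the class below is only
# a stand-in for rconst, and the values are made up.
import unittest.mock


class _FakeConstants:
  SYNTHETIC_BATCHES_PER_EPOCH = 4000


with unittest.mock.patch.object(_FakeConstants, 'SYNTHETIC_BATCHES_PER_EPOCH', 100):
  assert _FakeConstants.SYNTHETIC_BATCHES_PER_EPOCH == 100  # patched inside the block
assert _FakeConstants.SYNTHETIC_BATCHES_PER_EPOCH == 4000  # restored afterwards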
...@@ -43,14 +43,15 @@ def get_tf_dtype(flags_obj):


def get_loss_scale(flags_obj, default_for_fp16):
  dtype = get_tf_dtype(flags_obj)
  if flags_obj.loss_scale == "dynamic":
    return flags_obj.loss_scale
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  elif dtype == tf.float32 or dtype == tf.bfloat16:
    return 1  # No loss scaling is needed for fp32
  else:
    assert dtype == tf.float16
    return default_for_fp16
...
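# Worked example (hypothetical flag values): with the rewritten logic above,
# --dtype=fp16 and no explicit --loss_scale yields the default_for_fp16 value
# (e.g. "dynamic"); --dtype=bf16 and --dtype=fp32 now both return 1, since
# neither needs loss scaling; and an explicit numeric --loss_scale such as 128
# is returned as 128.0 regardless of dtype.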
...@@ -20,7 +20,6 @@ from __future__ import print_function
import os

from absl import flags
import tensorflow as tf

from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
...@@ -36,78 +35,6 @@ LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
]
def learning_rate_schedule(current_epoch,
current_batch,
steps_per_epoch,
batch_size):
"""Handles linear scaling rule, gradual warmup, and LR decay.
Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
provided scaling factor.
Args:
current_epoch: integer, current epoch indexed from 0.
current_batch: integer, current batch in the current epoch, indexed from 0.
steps_per_epoch: integer, number of steps in an epoch.
batch_size: integer, total batch size.
Returns:
Adjusted learning rate.
"""
initial_lr = BASE_LEARNING_RATE * batch_size / 256
epoch = current_epoch + float(current_batch) / steps_per_epoch
warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
if epoch < warmup_end_epoch:
# Learning rate increases linearly per step.
return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
for mult, start_epoch in LR_SCHEDULE:
if epoch >= start_epoch:
learning_rate = initial_lr * mult
else:
break
return learning_rate
class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
"""Callback to update learning rate on every batch (not epoch boundaries).
N.B. Only supports Keras optimizers, not TF optimizers.
Attributes:
schedule: a function that takes an epoch index and a batch index as input
(both integer, indexed from 0) and returns a new learning rate as
output (float).
"""
def __init__(self, schedule, batch_size, steps_per_epoch):
super(LearningRateBatchScheduler, self).__init__()
self.schedule = schedule
self.steps_per_epoch = steps_per_epoch
self.batch_size = batch_size
self.epochs = -1
self.prev_lr = -1
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'learning_rate'):
raise ValueError('Optimizer must have a "learning_rate" attribute.')
self.epochs += 1
def on_batch_begin(self, batch, logs=None):
"""Executes before step begins."""
lr = self.schedule(self.epochs,
batch,
self.steps_per_epoch,
self.batch_size)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function should be float.')
if lr != self.prev_lr:
self.model.optimizer.learning_rate = lr # lr should be a float here
self.prev_lr = lr
tf.compat.v1.logging.debug(
'Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.', self.epochs, batch, lr)
class PiecewiseConstantDecayWithWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
  """Piecewise constant decay with warmup schedule."""
...@@ -180,10 +107,8 @@ def get_optimizer(learning_rate=0.1):
  return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)


def get_callbacks(
    steps_per_epoch,
    pruning_method=None,
    enable_checkpoint_and_export=False,
    model_dir=None):
...@@ -194,13 +119,6 @@ def get_callbacks(
      logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
  callbacks = [time_callback]

  if FLAGS.enable_tensorboard:
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=FLAGS.model_dir)
...@@ -317,7 +235,7 @@ def define_keras_flags(
                       help='Whether to use a trivial Keras model.')
  flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
                       help='Report metrics during training and evaluation.')
  flags.DEFINE_boolean(name='use_tensor_lr', default=True,
                       help='Use learning rate tensor instead of a callback.')
  flags.DEFINE_boolean(
      name='enable_tensorboard', default=False,
...
...@@ -23,6 +23,7 @@ from absl import flags
from absl import logging
import tensorflow as tf

from official.modeling import performance
from official.staging.training import controller
from official.utils.flags import core as flags_core
from official.utils.logs import logger
...@@ -110,16 +111,7 @@ def run(flags_obj):
  keras_utils.set_session_config(
      enable_eager=flags_obj.enable_eager,
      enable_xla=flags_obj.enable_xla)
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  # This only affects GPU.
  common.set_cudnn_batchnorm_mode()
...
...@@ -28,6 +28,7 @@ import tensorflow as tf
import tensorflow_model_optimization as tfmot

from official.benchmark.models import trivial_model
from official.modeling import performance
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
...@@ -65,17 +66,9 @@ def run(flags_obj):
  common.set_cudnn_batchnorm_mode()

  dtype = flags_core.get_tf_dtype(flags_obj)
  performance.set_mixed_precision_policy(
      flags_core.get_tf_dtype(flags_obj),
      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

  data_format = flags_obj.data_format
  if data_format is None:
...@@ -155,23 +148,19 @@ def run(flags_obj):
      dtype=dtype,
      drop_remainder=drop_remainder)

  lr_schedule = common.PiecewiseConstantDecayWithWarmup(
      batch_size=flags_obj.batch_size,
      epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
      warmup_epochs=common.LR_SCHEDULE[0][1],
      boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
      multipliers=list(p[0] for p in common.LR_SCHEDULE),
      compute_lr_on_cpu=True)
  steps_per_epoch = (
      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)

  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
      optimizer = common.get_optimizer(lr_schedule)
    elif flags_obj.optimizer == 'mobilenet_default':
      initial_learning_rate = \
          flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
...@@ -248,7 +237,6 @@ def run(flags_obj):
  callbacks = common.get_callbacks(
      steps_per_epoch=steps_per_epoch,
      pruning_method=flags_obj.pruning_method,
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)
...
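# Hedged usage sketch (illustration only): the tensor-based schedule configured
# in run() above replaces the old per-batch LearningRateBatchScheduler callback.
# Since PiecewiseConstantDecayWithWarmup is a tf.keras LearningRateSchedule, the
# optimizer queries it with the current global step, and it can also be probed
# directly, e.g. lr_schedule(1000) returns the warmup or decayed learning rate
# at step 1000. The constructor arguments shown above (batch_size, epoch_size,
# warmup_epochs, boundaries, multipliers, compute_lr_on_cpu) are the ones this
# change relies on.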