Commit 338a0fc2 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 307641075
parent 9c1887a8
@@ -608,26 +608,36 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
         loss_scale='dynamic',
         dataset_num_private_threads=48)
 
-  def benchmark_2x2_tpu_fp16(self):
-    """Test Keras model with 2x2 TPU, fp16."""
+  def benchmark_2x2_tpu_bf16(self):
+    """Test Keras model with 2x2 TPU, bf16."""
     self._setup()
     self._run_and_report_benchmark(
-        experiment_name='benchmark_2x2_tpu_fp16',
+        experiment_name='benchmark_2x2_tpu_bf16',
         dtype='bfloat16',
         num_tpus=8,
         distribution_strategy='tpu',
         per_replica_batch_size=128)
 
-  def benchmark_4x4_tpu_fp16(self):
-    """Test Keras model with 4x4 TPU, fp16."""
+  def benchmark_4x4_tpu_bf16(self):
+    """Test Keras model with 4x4 TPU, bf16."""
     self._setup()
     self._run_and_report_benchmark(
-        experiment_name='benchmark_4x4_tpu_fp16',
+        experiment_name='benchmark_4x4_tpu_bf16',
         dtype='bfloat16',
         num_tpus=32,
         distribution_strategy='tpu',
         per_replica_batch_size=128)
 
+  def benchmark_8x8_tpu_bf16(self):
+    """Test Keras model with 8x8 TPU, bf16."""
+    self._setup()
+    self._run_and_report_benchmark(
+        experiment_name='benchmark_8x8_tpu_bf16',
+        dtype='bfloat16',
+        num_tpus=128,
+        distribution_strategy='tpu',
+        per_replica_batch_size=64)
+
   def fill_report_object(self, stats):
     super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object(
         stats,
@@ -1031,26 +1041,36 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.datasets_num_private_threads = 48
     self._run_and_report_benchmark()
 
-  def benchmark_2x2_tpu_fp16(self):
-    """Test Keras model with 2x2 TPU, fp16."""
+  def benchmark_2x2_tpu_bf16(self):
+    """Test Keras model with 2x2 TPU, bf16."""
     self._setup()
     FLAGS.dtype = 'bf16'
     FLAGS.distribution_strategy = 'tpu'
-    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_fp16')
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_bf16')
     FLAGS.batch_size = 1024
     self._run_and_report_benchmark()
 
-  def benchmark_4x4_tpu_fp16(self):
-    """Test Keras model with 4x4 TPU, fp16."""
+  def benchmark_4x4_tpu_bf16(self):
+    """Test Keras model with 4x4 TPU, bf16."""
     self._setup()
     FLAGS.dtype = 'bf16'
     FLAGS.distribution_strategy = 'tpu'
-    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_fp16')
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_bf16')
    FLAGS.batch_size = 4096
     self._run_and_report_benchmark()
 
+  def benchmark_8x8_tpu_bf16(self):
+    """Test Keras model with 8x8 TPU, bf16."""
+    self._setup()
+    FLAGS.dtype = 'bf16'
+    FLAGS.distribution_strategy = 'tpu'
+    FLAGS.model_dir = self._get_model_dir('benchmark_8x8_tpu_bf16')
+    FLAGS.batch_size = 8192
+    self._run_and_report_benchmark()
+
   def fill_report_object(self, stats):
     super(Resnet50KerasBenchmarkBase, self).fill_report_object(
         stats,
......
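A note on the batch sizes in the two hunks above: the classifier benchmarks pass a per-replica batch size together with a TPU core count, while the FLAGS-based variants set the global `FLAGS.batch_size` directly. The global values are simply the per-replica sizes multiplied by the replica counts; the snippet below is an illustrative check only and is not part of the commit.

```python
# Illustrative check only -- not part of the commit. The global FLAGS.batch_size
# equals the per-replica batch size times the number of TPU replicas used above.
configs = {
    'benchmark_2x2_tpu_bf16': (128, 8),    # (per-replica batch, replicas)
    'benchmark_4x4_tpu_bf16': (128, 32),
    'benchmark_8x8_tpu_bf16': (64, 128),
}
for name, (per_replica, replicas) in configs.items():
  print(name, per_replica * replicas)  # 1024, 4096, 8192 -- matches FLAGS.batch_size
```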
@@ -38,11 +38,10 @@ class CtlBenchmark(PerfZeroBenchmark):
   """Base benchmark class with methods to simplify testing."""
 
   def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
-    self.output_dir = output_dir
     self.default_flags = default_flags or {}
     self.flag_methods = flag_methods or {}
     super(CtlBenchmark, self).__init__(
-        output_dir=self.output_dir,
+        output_dir=output_dir,
         default_flags=self.default_flags,
         flag_methods=self.flag_methods)
@@ -186,9 +185,6 @@ class Resnet50CtlAccuracy(CtlBenchmark):
         log_steps=100,
         start_time_sec=start_time_sec)
 
-  def _get_model_dir(self, folder_name):
-    return os.path.join(self.output_dir, folder_name)
-
 
 class Resnet50CtlBenchmarkBase(CtlBenchmark):
   """Resnet50 benchmarks."""
@@ -207,16 +203,14 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     stats = resnet_ctl_imagenet_main.run(FLAGS)
     wall_time_sec = time.time() - start_time_sec
 
-    # Number of logged step time entries that are excluded in performance
-    # report. We keep results from last 100 batches in this case.
-    warmup = (FLAGS.train_steps - 100) // FLAGS.log_steps
-
+    # Warmup means the number of logged step time entries that are excluded in
+    # performance report. Default to exclude 1 FLAGS.log_steps time.
     super(Resnet50CtlBenchmarkBase, self)._report_benchmark(
         stats,
         wall_time_sec,
         total_batch_size=FLAGS.batch_size,
         log_steps=FLAGS.log_steps,
-        warmup=warmup,
+        warmup=1,
         start_time_sec=start_time_sec)
 
   def benchmark_1_gpu_no_dist_strat(self):
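The hunk above replaces the computed warmup with a fixed `warmup=1`, so the first logged step-time entry (one `FLAGS.log_steps` interval, typically dominated by tracing and compilation) is dropped from the reported throughput. Below is a minimal sketch of that averaging logic with hypothetical names; it is not the actual `_report_benchmark` implementation.

```python
# Minimal sketch, not the real _report_benchmark: shows how `warmup` logged
# entries are excluded before averaging throughput.
def average_examples_per_sec(interval_times_sec, batch_size, log_steps, warmup=1):
  """Averages examples/sec over logged intervals, dropping `warmup` entries."""
  kept = interval_times_sec[warmup:]  # e.g. warmup=1 skips the first interval
  if not kept:
    raise ValueError('No intervals left after excluding warmup entries.')
  return sum(batch_size * log_steps / t for t in kept) / len(kept)

# With log_steps=50 and warmup=1, the slow first interval (compilation) is ignored.
print(average_examples_per_sec([30.0, 5.0, 5.2], batch_size=1024, log_steps=50))
```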
@@ -373,6 +367,41 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.enable_xla = True
     self._run_and_report_benchmark()
 
+  def _set_df_common(self):
+    FLAGS.steps_per_loop = 500
+    FLAGS.train_epochs = 2
+    FLAGS.train_steps = None
+    FLAGS.skip_eval = True
+    FLAGS.enable_eager = True
+    FLAGS.enable_tensorboard = False
+    FLAGS.distribution_strategy = 'tpu'
+    FLAGS.report_accuracy_metrics = False
+    FLAGS.log_steps = 50
+    FLAGS.single_l2_loss_op = True
+    FLAGS.use_tf_function = True
+    FLAGS.enable_checkpoint_and_export = False
+
+  def benchmark_2x2_tpu_bf16(self):
+    self._setup()
+    self._set_df_common()
+    FLAGS.batch_size = 1024
+    FLAGS.dtype = 'bf16'
+    self._run_and_report_benchmark()
+
+  def benchmark_4x4_tpu_bf16(self):
+    self._setup()
+    self._set_df_common()
+    FLAGS.batch_size = 4096
+    FLAGS.dtype = 'bf16'
+    self._run_and_report_benchmark()
+
+  def benchmark_8x16_tpu_bf16(self):
+    self._setup()
+    self._set_df_common()
+    FLAGS.batch_size = 8192
+    FLAGS.dtype = 'bf16'
+    self._run_and_report_benchmark()
+
   def fill_report_object(self, stats):
     super(Resnet50CtlBenchmarkBase, self).fill_report_object(
         stats, total_batch_size=FLAGS.batch_size, log_steps=FLAGS.log_steps)
......
@@ -163,7 +163,7 @@ def run(flags_obj):
   resnet_controller = controller.Controller(
       strategy,
       runnable.train,
-      runnable.evaluate,
+      runnable.evaluate if not flags_obj.skip_eval else None,
       global_step=runnable.global_step,
       steps_per_loop=steps_per_loop,
       train_steps=per_epoch_steps * train_epochs,
......
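The final hunk passes `None` as the evaluation callable when `--skip_eval` is set, which the new CTL TPU benchmarks above rely on (they set `FLAGS.skip_eval = True`). The sketch below only illustrates the conditional-callable pattern with made-up names; it is not the `controller.Controller` implementation.

```python
# Illustrative pattern only (hypothetical names, not controller.Controller):
# a loop that trains in chunks and evaluates only when an eval callable exists.
from typing import Callable, Optional

def run_loop(train_fn: Callable[[int], None],
             eval_fn: Optional[Callable[[], None]],
             train_steps: int,
             steps_per_loop: int) -> None:
  step = 0
  while step < train_steps:
    num_steps = min(steps_per_loop, train_steps - step)
    train_fn(num_steps)
    step += num_steps
    if eval_fn is not None:   # eval_fn is None when evaluation is skipped
      eval_fn()

# Mirrors the diff: pass the evaluate callable only when evaluation is wanted.
skip_eval = True
run_loop(train_fn=lambda n: None,
         eval_fn=None if skip_eval else (lambda: print('eval')),
         train_steps=100,
         steps_per_loop=50)
```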