Commit 0cc16aa5 authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307462714
parent 0d968ea2
...@@ -320,6 +320,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -320,6 +320,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
top_1_min: float = MIN_TOP_1_ACCURACY, top_1_min: float = MIN_TOP_1_ACCURACY,
top_1_max: float = MAX_TOP_1_ACCURACY, top_1_max: float = MAX_TOP_1_ACCURACY,
num_gpus: int = 0, num_gpus: int = 0,
num_tpus: int = 0,
distribution_strategy: str = 'mirrored', distribution_strategy: str = 'mirrored',
per_replica_batch_size: int = 128, per_replica_batch_size: int = 128,
epochs_between_evals: int = 1, epochs_between_evals: int = 1,
...@@ -350,7 +351,10 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -350,7 +351,10 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
dataset_num_private_threads=dataset_num_private_threads, dataset_num_private_threads=dataset_num_private_threads,
loss_scale=loss_scale) loss_scale=loss_scale)
FLAGS.params_override = json.dumps(parameters) FLAGS.params_override = json.dumps(parameters)
total_batch_size = num_gpus * per_replica_batch_size if distribution_strategy == 'tpu':
total_batch_size = num_tpus * per_replica_batch_size
else:
total_batch_size = num_gpus * per_replica_batch_size
start_time_sec = time.time() start_time_sec = time.time()
stats = classifier_trainer.run(flags.FLAGS) stats = classifier_trainer.run(flags.FLAGS)
...@@ -610,6 +614,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -610,6 +614,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
self._run_and_report_benchmark( self._run_and_report_benchmark(
experiment_name='benchmark_2x2_tpu_fp16', experiment_name='benchmark_2x2_tpu_fp16',
dtype='bfloat16', dtype='bfloat16',
num_tpus=8,
distribution_strategy='tpu', distribution_strategy='tpu',
per_replica_batch_size=128) per_replica_batch_size=128)
...@@ -619,6 +624,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -619,6 +624,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
self._run_and_report_benchmark( self._run_and_report_benchmark(
experiment_name='benchmark_4x4_tpu_fp16', experiment_name='benchmark_4x4_tpu_fp16',
dtype='bfloat16', dtype='bfloat16',
num_tpus=32,
distribution_strategy='tpu', distribution_strategy='tpu',
per_replica_batch_size=128) per_replica_batch_size=128)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment