Commit d0454314 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 328382317
parent 543ee2c4
...@@ -48,6 +48,7 @@ FLAGS = flags.FLAGS ...@@ -48,6 +48,7 @@ FLAGS = flags.FLAGS
def _get_classifier_parameters( def _get_classifier_parameters(
model_variant: Optional[str] = None,
num_gpus: int = 0, num_gpus: int = 0,
builder: str = 'records', builder: str = 'records',
skip_eval: bool = False, skip_eval: bool = False,
...@@ -65,7 +66,7 @@ def _get_classifier_parameters( ...@@ -65,7 +66,7 @@ def _get_classifier_parameters(
report_metrics: bool = True, report_metrics: bool = True,
batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]: batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
"""Gets classifier trainer's ResNet parameters.""" """Gets classifier trainer's ResNet parameters."""
return { params = {
'runtime': { 'runtime': {
'num_gpus': num_gpus, 'num_gpus': num_gpus,
'distribution_strategy': distribution_strategy, 'distribution_strategy': distribution_strategy,
...@@ -110,6 +111,11 @@ def _get_classifier_parameters( ...@@ -110,6 +111,11 @@ def _get_classifier_parameters(
'skip_eval': skip_eval, 'skip_eval': skip_eval,
}, },
} }
if model_variant is not None:
params['model']['model_params'] = {
'model_name': model_variant,
}
return params
class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark): class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
...@@ -323,6 +329,7 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -323,6 +329,7 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
def _run_and_report_benchmark( def _run_and_report_benchmark(
self, self,
experiment_name: str, experiment_name: str,
model_variant: Optional[str] = None,
skip_steps: Optional[int] = None, skip_steps: Optional[int] = None,
top_1_min: float = MIN_TOP_1_ACCURACY, top_1_min: float = MIN_TOP_1_ACCURACY,
top_1_max: float = MAX_TOP_1_ACCURACY, top_1_max: float = MAX_TOP_1_ACCURACY,
...@@ -344,6 +351,7 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -344,6 +351,7 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.data_dir = self.data_dir FLAGS.data_dir = self.data_dir
FLAGS.model_dir = self._get_model_dir(experiment_name) FLAGS.model_dir = self._get_model_dir(experiment_name)
parameters = _get_classifier_parameters( parameters = _get_classifier_parameters(
model_variant=model_variant,
builder=self.dataset_builder, builder=self.dataset_builder,
skip_eval=True, skip_eval=True,
num_gpus=num_gpus, num_gpus=num_gpus,
...@@ -1147,6 +1155,16 @@ class EfficientNetKerasBenchmarkReal(KerasClassifierBenchmarkBase): ...@@ -1147,6 +1155,16 @@ class EfficientNetKerasBenchmarkReal(KerasClassifierBenchmarkBase):
tpu=tpu, dataset_builder='records', train_epochs=1, train_steps=110, tpu=tpu, dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir) data_dir=data_dir)
def benchmark_2x2_tpu_b7_bf16(self):
self._setup()
self._run_and_report_benchmark(
experiment_name='benchmark_b7_2x2_tpu_bf16',
model_variant='efficientnet-b7',
dtype='bfloat16',
num_tpus=8,
distribution_strategy='tpu',
per_replica_batch_size=128)
class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase): class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
"""Resnet50 real data (stored in remote storage) benchmark tests.""" """Resnet50 real data (stored in remote storage) benchmark tests."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment