Commit e91779b8 authored by Allen Wang, committed by A. Unique TensorFlower

Classifier trainer GPU throughput fixes.

PiperOrigin-RevId: 309776126
parent 158c96d2
@@ -61,7 +61,8 @@ def _get_classifier_parameters(
     run_eagerly: bool = False,
     gpu_thread_mode: Optional[str] = None,
     dataset_num_private_threads: Optional[int] = None,
-    loss_scale: Optional[str] = None) -> MutableMapping[str, Any]:
+    loss_scale: Optional[str] = None,
+    batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
   """Gets classifier trainer's ResNet parameters."""
   return {
       'runtime': {
@@ -72,6 +73,7 @@ def _get_classifier_parameters(
           'dataset_num_private_threads': dataset_num_private_threads,
           'gpu_thread_mode': gpu_thread_mode,
           'loss_scale': loss_scale,
+          'batchnorm_spatial_persistent': batchnorm_spatial_persistent,
       },
       'train_dataset': {
           'builder': builder,
@@ -167,7 +169,8 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
         run_eagerly=run_eagerly,
         gpu_thread_mode=gpu_thread_mode,
         dataset_num_private_threads=dataset_num_private_threads,
-        loss_scale=loss_scale)
+        loss_scale=loss_scale,
+        batchnorm_spatial_persistent=True)
     FLAGS.params_override = json.dumps(parameters)
     total_batch_size = num_gpus * per_replica_batch_size
@@ -349,7 +352,8 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
         enable_xla=enable_xla,
         gpu_thread_mode=gpu_thread_mode,
         dataset_num_private_threads=dataset_num_private_threads,
-        loss_scale=loss_scale)
+        loss_scale=loss_scale,
+        batchnorm_spatial_persistent=True)
     FLAGS.params_override = json.dumps(parameters)
     if distribution_strategy == 'tpu':
       total_batch_size = num_tpus * per_replica_batch_size
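Both benchmark classes above opt in to the new flag by passing batchnorm_spatial_persistent=True into the parameter dict, which is serialized into FLAGS.params_override and later parsed back into the trainer's experiment config. A minimal sketch of that flow, assuming a trimmed-down parameter dict (the real _get_classifier_parameters returns many more keys):

import json

# Trimmed stand-in for the dict built by _get_classifier_parameters.
parameters = {
    'runtime': {
        'loss_scale': 'dynamic',
        # New in this commit: opt in to the persistent cuDNN batch norm kernel.
        'batchnorm_spatial_persistent': True,
    },
}

# The classifier trainer parses this JSON back into its ExperimentConfig,
# overriding the YAML defaults.
params_override = json.dumps(parameters)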
@@ -273,6 +273,8 @@ class RuntimeConfig(Config):
     loss_scale: The type of loss scale. This is used when setting the mixed
       precision policy.
     run_eagerly: Whether or not to run the experiment eagerly.
+    batchnorm_spatial_persistent: Whether or not to enable the spatial
+      persistent mode for CuDNN batch norm kernel for improved GPU performance.
   """
   distribution_strategy: str = 'mirrored'
@@ -288,6 +290,7 @@ class RuntimeConfig(Config):
   num_packs: int = 1
   loss_scale: Optional[str] = None
   run_eagerly: bool = False
+  batchnorm_spatial_persistent: bool = False
 
 
 @dataclasses.dataclass
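For reference, a self-contained sketch of the resulting config shape, assuming a plain dataclass (the real RuntimeConfig subclasses the project's Config base class and carries more fields):

import dataclasses
from typing import Optional

@dataclasses.dataclass
class RuntimeConfig:
  """Trimmed stand-in for the real RuntimeConfig."""
  distribution_strategy: str = 'mirrored'
  num_packs: int = 1
  loss_scale: Optional[str] = None
  run_eagerly: bool = False
  # New field added by this commit; off by default, enabled per-config.
  batchnorm_spatial_persistent: bool = False

config = RuntimeConfig(batchnorm_spatial_persistent=True)
assert config.batchnorm_spatial_persistent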
@@ -228,13 +228,6 @@ def initialize(params: base_configs.ExperimentConfig,
   """Initializes backend related initializations."""
   keras_utils.set_session_config(
       enable_xla=params.runtime.enable_xla)
-  if params.runtime.gpu_thread_mode:
-    keras_utils.set_gpu_thread_mode_and_count(
-        per_gpu_thread_count=params.runtime.per_gpu_thread_count,
-        gpu_thread_mode=params.runtime.gpu_thread_mode,
-        num_gpus=params.runtime.num_gpus,
-        datasets_num_private_threads=params.runtime.dataset_num_private_threads)
   performance.set_mixed_precision_policy(dataset_builder.dtype,
                                          get_loss_scale(params))
   if tf.config.list_physical_devices('GPU'):
@@ -248,6 +241,15 @@ def initialize(params: base_configs.ExperimentConfig,
   if params.runtime.run_eagerly:
     # Enable eager execution to allow step-by-step debugging
     tf.config.experimental_run_functions_eagerly(True)
+  if tf.config.list_physical_devices('GPU'):
+    if params.runtime.gpu_thread_mode:
+      keras_utils.set_gpu_thread_mode_and_count(
+          per_gpu_thread_count=params.runtime.per_gpu_thread_count,
+          gpu_thread_mode=params.runtime.gpu_thread_mode,
+          num_gpus=params.runtime.num_gpus,
+          datasets_num_private_threads=params.runtime.dataset_num_private_threads)  # pylint:disable=line-too-long
+    if params.runtime.batchnorm_spatial_persistent:
+      os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
 
 
 def define_classifier_flags():
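The initialize change does two things: it moves the GPU thread-mode setup behind a tf.config.list_physical_devices('GPU') check, so it only runs when a GPU is actually visible, and it sets the TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT environment variable during initialization, before any training kernels execute. A hedged sketch of the resulting logic; configure_gpu_runtime and the runtime argument are illustrative names, not from the diff:

import os
import tensorflow as tf

def configure_gpu_runtime(runtime) -> None:
  """Applies GPU-only runtime settings; runtime mirrors RuntimeConfig."""
  if not tf.config.list_physical_devices('GPU'):
    return
  if runtime.gpu_thread_mode:
    # keras_utils.set_gpu_thread_mode_and_count does this (plus per-GPU
    # thread-count bookkeeping) in the diff above.
    os.environ['TF_GPU_THREAD_MODE'] = runtime.gpu_thread_mode
  if runtime.batchnorm_spatial_persistent:
    # TensorFlow consults this variable when choosing the cuDNN batch norm
    # kernel, so it is set here, before the first batch norm op runs.
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'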
@@ -4,6 +4,7 @@
 runtime:
   distribution_strategy: 'mirrored'
   num_gpus: 1
+  batchnorm_spatial_persistent: True
 train_dataset:
   name: 'imagenet2012'
   data_dir: null
@@ -12,9 +13,9 @@ train_dataset:
   image_size: 224
   num_classes: 1000
   num_examples: 1281167
-  batch_size: 128
+  batch_size: 64
   use_per_replica_batch_size: True
-  dtype: 'float32'
+  dtype: 'float16'
   mean_subtract: True
   standardize: True
 validation_dataset:
@@ -25,9 +26,9 @@ validation_dataset:
   image_size: 224
   num_classes: 1000
   num_examples: 50000
-  batch_size: 128
+  batch_size: 64
   use_per_replica_batch_size: True
-  dtype: 'float32'
+  dtype: 'float16'
   mean_subtract: True
   standardize: True
 model:
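The GPU config now trains and evaluates in float16 with a per-replica batch size of 64, down from 128 in float32; because use_per_replica_batch_size is True, the global batch size still scales with num_gpus. Downstream, a float16 dtype causes initialize to install a mixed-precision policy via performance.set_mixed_precision_policy (see the hunk above). A sketch of what that amounts to, assuming the experimental Keras mixed-precision API of the TF 2.1/2.2 era:

import tensorflow as tf

# Compute in float16, keep variables in float32, and dynamically scale the
# loss so small gradients do not underflow in float16.
policy = tf.keras.mixed_precision.experimental.Policy(
    'mixed_float16', loss_scale='dynamic')
tf.keras.mixed_precision.experimental.set_policy(policy)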
@@ -98,10 +98,6 @@ class DatasetConfig(base_config.Config):
     file_shuffle_buffer_size: The buffer size used for shuffling raw training
       files.
     skip_decoding: Whether to skip image decoding when loading from TFDS.
-    deterministic_train: Whether the examples in the training set should output
-      in a deterministic order.
-    use_slack: whether to introduce slack in the last prefetch. This may reduce
-      CPU contention at the start of a training step.
     cache: whether to cache to dataset examples. Can be used to avoid re-reading
       from disk on the second epoch. Requires significant memory overhead.
     mean_subtract: whether or not to apply mean subtraction to the dataset.
@@ -126,8 +122,6 @@ class DatasetConfig(base_config.Config):
   shuffle_buffer_size: int = 10000
   file_shuffle_buffer_size: int = 1024
   skip_decoding: bool = True
-  deterministic_train: bool = False
-  use_slack: bool = True
   cache: bool = False
   mean_subtract: bool = False
   standardize: bool = False
@@ -452,16 +446,6 @@ class DatasetBuilder:
     dataset = dataset.batch(self.global_batch_size,
                             drop_remainder=self.is_training)
 
-    if self.is_training:
-      options = tf.data.Options()
-      options.experimental_deterministic = self.config.deterministic_train
-      options.experimental_slack = self.config.use_slack
-      options.experimental_optimization.parallel_batch = True
-      options.experimental_optimization.map_fusion = True
-      options.experimental_optimization.map_vectorization.enabled = True
-      options.experimental_optimization.map_parallelization = True
-      dataset = dataset.with_options(options)
-
     # Prefetch overlaps in-feed with training
     dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
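With deterministic_train and use_slack gone, the training pipeline no longer installs a hand-tuned tf.data.Options block and instead relies on tf.data's defaults; the tail of the pipeline reduces to batching plus autotuned prefetch. A sketch under that assumption (finalize_pipeline is an illustrative name; the logic lives inline in DatasetBuilder):

import tensorflow as tf

def finalize_pipeline(dataset: tf.data.Dataset,
                      global_batch_size: int,
                      is_training: bool) -> tf.data.Dataset:
  """Batches and prefetches, matching the post-commit pipeline tail."""
  dataset = dataset.batch(global_batch_size, drop_remainder=is_training)
  # Prefetch overlaps in-feed with training.
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)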