Commit e91779b8 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Classifier trainer GPU throughput fixes.

PiperOrigin-RevId: 309776126
parent 158c96d2
...@@ -61,7 +61,8 @@ def _get_classifier_parameters( ...@@ -61,7 +61,8 @@ def _get_classifier_parameters(
run_eagerly: bool = False, run_eagerly: bool = False,
gpu_thread_mode: Optional[str] = None, gpu_thread_mode: Optional[str] = None,
dataset_num_private_threads: Optional[int] = None, dataset_num_private_threads: Optional[int] = None,
loss_scale: Optional[str] = None) -> MutableMapping[str, Any]: loss_scale: Optional[str] = None,
batchnorm_spatial_persistent: bool = False) -> MutableMapping[str, Any]:
"""Gets classifier trainer's ResNet parameters.""" """Gets classifier trainer's ResNet parameters."""
return { return {
'runtime': { 'runtime': {
...@@ -72,6 +73,7 @@ def _get_classifier_parameters( ...@@ -72,6 +73,7 @@ def _get_classifier_parameters(
'dataset_num_private_threads': dataset_num_private_threads, 'dataset_num_private_threads': dataset_num_private_threads,
'gpu_thread_mode': gpu_thread_mode, 'gpu_thread_mode': gpu_thread_mode,
'loss_scale': loss_scale, 'loss_scale': loss_scale,
'batchnorm_spatial_persistent': batchnorm_spatial_persistent,
}, },
'train_dataset': { 'train_dataset': {
'builder': builder, 'builder': builder,
...@@ -167,7 +169,8 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -167,7 +169,8 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
run_eagerly=run_eagerly, run_eagerly=run_eagerly,
gpu_thread_mode=gpu_thread_mode, gpu_thread_mode=gpu_thread_mode,
dataset_num_private_threads=dataset_num_private_threads, dataset_num_private_threads=dataset_num_private_threads,
loss_scale=loss_scale) loss_scale=loss_scale,
batchnorm_spatial_persistent=True)
FLAGS.params_override = json.dumps(parameters) FLAGS.params_override = json.dumps(parameters)
total_batch_size = num_gpus * per_replica_batch_size total_batch_size = num_gpus * per_replica_batch_size
...@@ -349,7 +352,8 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -349,7 +352,8 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
enable_xla=enable_xla, enable_xla=enable_xla,
gpu_thread_mode=gpu_thread_mode, gpu_thread_mode=gpu_thread_mode,
dataset_num_private_threads=dataset_num_private_threads, dataset_num_private_threads=dataset_num_private_threads,
loss_scale=loss_scale) loss_scale=loss_scale,
batchnorm_spatial_persistent=True)
FLAGS.params_override = json.dumps(parameters) FLAGS.params_override = json.dumps(parameters)
if distribution_strategy == 'tpu': if distribution_strategy == 'tpu':
total_batch_size = num_tpus * per_replica_batch_size total_batch_size = num_tpus * per_replica_batch_size
......
...@@ -273,6 +273,8 @@ class RuntimeConfig(Config): ...@@ -273,6 +273,8 @@ class RuntimeConfig(Config):
loss_scale: The type of loss scale. This is used when setting the mixed loss_scale: The type of loss scale. This is used when setting the mixed
precision policy. precision policy.
run_eagerly: Whether or not to run the experiment eagerly. run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
""" """
distribution_strategy: str = 'mirrored' distribution_strategy: str = 'mirrored'
...@@ -288,6 +290,7 @@ class RuntimeConfig(Config): ...@@ -288,6 +290,7 @@ class RuntimeConfig(Config):
num_packs: int = 1 num_packs: int = 1
loss_scale: Optional[str] = None loss_scale: Optional[str] = None
run_eagerly: bool = False run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -228,13 +228,6 @@ def initialize(params: base_configs.ExperimentConfig, ...@@ -228,13 +228,6 @@ def initialize(params: base_configs.ExperimentConfig,
"""Initializes backend related initializations.""" """Initializes backend related initializations."""
keras_utils.set_session_config( keras_utils.set_session_config(
enable_xla=params.runtime.enable_xla) enable_xla=params.runtime.enable_xla)
if params.runtime.gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads)
performance.set_mixed_precision_policy(dataset_builder.dtype, performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params)) get_loss_scale(params))
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
...@@ -248,6 +241,15 @@ def initialize(params: base_configs.ExperimentConfig, ...@@ -248,6 +241,15 @@ def initialize(params: base_configs.ExperimentConfig,
if params.runtime.run_eagerly: if params.runtime.run_eagerly:
# Enable eager execution to allow step-by-step debugging # Enable eager execution to allow step-by-step debugging
tf.config.experimental_run_functions_eagerly(True) tf.config.experimental_run_functions_eagerly(True)
if tf.config.list_physical_devices('GPU'):
if params.runtime.gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads) # pylint:disable=line-too-long
if params.runtime.batchnorm_spatial_persistent:
os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
def define_classifier_flags(): def define_classifier_flags():
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
runtime: runtime:
distribution_strategy: 'mirrored' distribution_strategy: 'mirrored'
num_gpus: 1 num_gpus: 1
batchnorm_spatial_persistent: True
train_dataset: train_dataset:
name: 'imagenet2012' name: 'imagenet2012'
data_dir: null data_dir: null
...@@ -12,9 +13,9 @@ train_dataset: ...@@ -12,9 +13,9 @@ train_dataset:
image_size: 224 image_size: 224
num_classes: 1000 num_classes: 1000
num_examples: 1281167 num_examples: 1281167
batch_size: 128 batch_size: 64
use_per_replica_batch_size: True use_per_replica_batch_size: True
dtype: 'float32' dtype: 'float16'
mean_subtract: True mean_subtract: True
standardize: True standardize: True
validation_dataset: validation_dataset:
...@@ -25,9 +26,9 @@ validation_dataset: ...@@ -25,9 +26,9 @@ validation_dataset:
image_size: 224 image_size: 224
num_classes: 1000 num_classes: 1000
num_examples: 50000 num_examples: 50000
batch_size: 128 batch_size: 64
use_per_replica_batch_size: True use_per_replica_batch_size: True
dtype: 'float32' dtype: 'float16'
mean_subtract: True mean_subtract: True
standardize: True standardize: True
model: model:
......
...@@ -98,10 +98,6 @@ class DatasetConfig(base_config.Config): ...@@ -98,10 +98,6 @@ class DatasetConfig(base_config.Config):
file_shuffle_buffer_size: The buffer size used for shuffling raw training file_shuffle_buffer_size: The buffer size used for shuffling raw training
files. files.
skip_decoding: Whether to skip image decoding when loading from TFDS. skip_decoding: Whether to skip image decoding when loading from TFDS.
deterministic_train: Whether the examples in the training set should output
in a deterministic order.
use_slack: whether to introduce slack in the last prefetch. This may reduce
CPU contention at the start of a training step.
cache: whether to cache to dataset examples. Can be used to avoid re-reading cache: whether to cache to dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead. from disk on the second epoch. Requires significant memory overhead.
mean_subtract: whether or not to apply mean subtraction to the dataset. mean_subtract: whether or not to apply mean subtraction to the dataset.
...@@ -126,8 +122,6 @@ class DatasetConfig(base_config.Config): ...@@ -126,8 +122,6 @@ class DatasetConfig(base_config.Config):
shuffle_buffer_size: int = 10000 shuffle_buffer_size: int = 10000
file_shuffle_buffer_size: int = 1024 file_shuffle_buffer_size: int = 1024
skip_decoding: bool = True skip_decoding: bool = True
deterministic_train: bool = False
use_slack: bool = True
cache: bool = False cache: bool = False
mean_subtract: bool = False mean_subtract: bool = False
standardize: bool = False standardize: bool = False
...@@ -452,16 +446,6 @@ class DatasetBuilder: ...@@ -452,16 +446,6 @@ class DatasetBuilder:
dataset = dataset.batch(self.global_batch_size, dataset = dataset.batch(self.global_batch_size,
drop_remainder=self.is_training) drop_remainder=self.is_training)
if self.is_training:
options = tf.data.Options()
options.experimental_deterministic = self.config.deterministic_train
options.experimental_slack = self.config.use_slack
options.experimental_optimization.parallel_batch = True
options.experimental_optimization.map_fusion = True
options.experimental_optimization.map_vectorization.enabled = True
options.experimental_optimization.map_parallelization = True
dataset = dataset.with_options(options)
# Prefetch overlaps in-feed with training # Prefetch overlaps in-feed with training
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment