Unverified Commit 7b5606a5 authored by Haoyu Zhang, committed by GitHub

Added thread tuning and tweaked tests to improve Keras model performance (#6396)

parent dba24007
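
The idea behind the tuning: in `gpu_private` mode each GPU gets its own small pool of threads for scheduling and launching kernels, plus one memory-copy thread, and whatever cores remain are handed to the `tf.data` preprocessing pool. A quick sanity check of that budget, assuming a hypothetical 96-core, 8-GPU machine (the machine shape is illustrative, not from the commit):

```python
cpu_count = 96            # what multiprocessing.cpu_count() might report
num_gpus = 8
per_gpu_thread_count = 2  # the default applied when --per_gpu_thread_count is 0

total_gpu_thread_count = per_gpu_thread_count * num_gpus  # 16 kernel-launch threads
num_mem_copy_threads = num_gpus                           # 8 memcpy threads, one per GPU

# Same arithmetic as set_gpu_thread_mode_and_count() in the diff below:
datasets_num_private_threads = (cpu_count - total_gpu_thread_count
                                - num_mem_copy_threads)
print(datasets_num_private_threads)  # 96 - 16 - 8 = 72 cores left for tf.data
```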
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import multiprocessing
+import os
 import time

 import numpy as np
@@ -160,6 +162,31 @@ def set_config_v2():
   )


+def set_gpu_thread_mode_and_count(flags_obj):
+  """Set GPU thread mode and count, and adjust dataset threads count."""
+  cpu_count = multiprocessing.cpu_count()
+  tf.compat.v1.logging.info('Logical CPU cores: %s', cpu_count)
+
+  # Allocate private thread pool for each GPU to schedule and launch kernels.
+  per_gpu_thread_count = flags_obj.per_gpu_thread_count or 2
+  os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode
+  os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
+  tf.compat.v1.logging.info('TF_GPU_THREAD_COUNT: %s',
+                            os.environ['TF_GPU_THREAD_COUNT'])
+  tf.compat.v1.logging.info('TF_GPU_THREAD_MODE: %s',
+                            os.environ['TF_GPU_THREAD_MODE'])
+
+  # Limit data preprocessing threadpool to CPU cores minus number of total GPU
+  # private threads and memory copy threads.
+  total_gpu_thread_count = per_gpu_thread_count * flags_obj.num_gpus
+  num_mem_copy_threads = flags_obj.num_gpus
+  if not flags_obj.datasets_num_private_threads:
+    flags_obj.datasets_num_private_threads = (cpu_count - total_gpu_thread_count
+                                              - num_mem_copy_threads)
+    tf.compat.v1.logging.info('Set datasets_num_private_threads to %s',
+                              flags_obj.datasets_num_private_threads)
+
+
 def get_optimizer():
   """Returns optimizer to use."""
   # The learning_rate is overwritten at the beginning of each step by callback.
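One caveat worth noting about the new helper (a general TensorFlow behavior, not something the diff itself states): `TF_GPU_THREAD_MODE` and `TF_GPU_THREAD_COUNT` are read by the runtime when GPU devices are initialized, so the environment variables must be in place before the first session or device is created. A minimal sketch of a safe ordering:

```python
import os

# Set the GPU threading knobs before TensorFlow initializes any GPU device;
# values mirror the defaults used by the new helper above.
os.environ.setdefault('TF_GPU_THREAD_MODE', 'gpu_private')
os.environ.setdefault('TF_GPU_THREAD_COUNT', '2')

import tensorflow as tf  # imported only after the variables are in place

with tf.Session() as sess:  # first device initialization happens here
  print(sess.list_devices())
```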
@@ -297,6 +297,20 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.batch_size = 256 * 8  # 8 GPUs
     self._run_and_report_benchmark()

+  def benchmark_xla_8_gpu_fp16_tweaked(self):
+    """Test Keras model with manual config tuning, XLA, 8 GPUs and fp16."""
+    self._setup()
+
+    FLAGS.num_gpus = 8
+    FLAGS.dtype = 'fp16'
+    FLAGS.enable_eager = True
+    FLAGS.enable_xla = True
+    FLAGS.distribution_strategy = 'default'
+    FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
+    FLAGS.batch_size = 256 * 8  # 8 GPUs
+    FLAGS.tf_gpu_thread_mode = 'gpu_private'
+    self._run_and_report_benchmark()
+
   def benchmark_graph_8_gpu(self):
     """Test Keras model in legacy graph mode with 8 GPUs."""
     self._setup()
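What `datasets_num_private_threads` ultimately controls is a private threadpool for the input pipeline, so `tf.data` workers do not compete with the GPUs' kernel-launch threads. As a rough sketch of the same idea using the public `tf.data` options API (the exact attribute has moved between releases, and this is not necessarily how the repo wires the flag through):

```python
import tensorflow as tf

def with_private_threadpool(dataset, num_threads):
  """Rough sketch: give `dataset` its own threadpool of `num_threads` workers.

  The attribute lives at `options.experimental_threading` in older releases
  and at `options.threading` from TF 2.7 on.
  """
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = num_threads
  return dataset.with_options(options)
```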
@@ -28,6 +28,7 @@ from official.resnet.keras import resnet_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
+from official.utils.misc import model_helpers

 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
@@ -102,6 +103,10 @@ def run(flags_obj):
     sess = tf.Session(config=config)
     tf.keras.backend.set_session(sess)

+  # Execute flag override logic for better model performance.
+  if flags_obj.tf_gpu_thread_mode:
+    keras_common.set_gpu_thread_mode_and_count(flags_obj)
+
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'float16':
     policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
@@ -135,13 +140,15 @@ def run(flags_obj):
       datasets_num_private_threads=flags_obj.datasets_num_private_threads,
       dtype=dtype)

-  eval_input_dataset = input_fn(
-      is_training=False,
-      data_dir=flags_obj.data_dir,
-      batch_size=flags_obj.batch_size,
-      num_epochs=flags_obj.train_epochs,
-      parse_record_fn=parse_record_keras,
-      dtype=dtype)
+  eval_input_dataset = None
+  if not flags_obj.skip_eval:
+    eval_input_dataset = input_fn(
+        is_training=False,
+        data_dir=flags_obj.data_dir,
+        batch_size=flags_obj.batch_size,
+        num_epochs=flags_obj.train_epochs,
+        parse_record_fn=parse_record_keras,
+        dtype=dtype)

   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
@@ -209,6 +216,7 @@ def run(flags_obj):

 def main(_):
+  model_helpers.apply_clean(flags.FLAGS)
   with logger.benchmark_context(flags.FLAGS):
     return run(flags.FLAGS)
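The eval-input change above follows a simple pattern: never build an input pipeline whose output will not be consumed, since constructing it still allocates threads and buffers. Reduced to its core (with `input_fn` simplified to a single argument for illustration):

```python
def build_datasets(flags_obj, input_fn):
  """Sketch of the guard above: skip the eval pipeline under --skip_eval."""
  train_dataset = input_fn(is_training=True)
  eval_dataset = None if flags_obj.skip_eval else input_fn(is_training=False)
  return train_dataset, eval_dataset
```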
@@ -156,6 +156,13 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
             "Whether and how the GPU device uses its own threadpool.")
     )

+    flags.DEFINE_integer(
+        name="per_gpu_thread_count", short_name="pgtc", default=0,
+        help=help_wrap(
+            "The number of threads to use for GPU. Only valid when "
+            "tf_gpu_thread_mode is not global.")
+    )
+
   if datasets_num_private_threads:
     flags.DEFINE_integer(
         name="datasets_num_private_threads",
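How the two new knobs look from the command line, reduced to a self-contained absl program (the flag names match the diff; everything else is scaffolding for illustration):

```python
from absl import app
from absl import flags

flags.DEFINE_string('tf_gpu_thread_mode', None,
                    'Whether and how the GPU device uses its own threadpool.')
flags.DEFINE_integer('per_gpu_thread_count', 0,
                     'Threads per GPU; 0 falls back to the default of 2.')

FLAGS = flags.FLAGS


def main(_):
  # Run as: python tune_demo.py --tf_gpu_thread_mode=gpu_private --per_gpu_thread_count=2
  if FLAGS.tf_gpu_thread_mode:
    per_gpu = FLAGS.per_gpu_thread_count or 2
    print('GPU thread mode: %s, %d threads per GPU'
          % (FLAGS.tf_gpu_thread_mode, per_gpu))


if __name__ == '__main__':
  app.run(main)
```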