Commit 18df1a09 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Enabled performance related parameters for Transformer: all_reduce_alg,

enable_eager, tf_gpu_thread_mode and datasets_num_private_threads

PiperOrigin-RevId: 282797514
parent ae989abc
......@@ -71,6 +71,9 @@ def define_transformer_flags():
dtype=True,
loss_scale=True,
all_reduce_alg=True,
num_packs=True,
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
enable_xla=True,
force_v2_in_keras_compile=True,
fp16_implementation=True
......@@ -90,6 +93,10 @@ def define_transformer_flags():
flags_core.define_benchmark()
flags_core.define_device(tpu=True)
flags.DEFINE_boolean(
name='enable_eager', default=False,
help='Enable eager mode? (Note: this is NOT run eagerly / op-by-op mode)')
flags.DEFINE_integer(
name='train_steps', short_name='ts', default=300000,
help=flags_core.help_wrap('The number of steps used to train.'))
......
......@@ -43,6 +43,7 @@ from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import keras_utils
from official.utils.misc import distribution_utils
from official.vision.image_classification import common as image_common
INF = int(1e9)
......@@ -164,6 +165,8 @@ class TransformerTask(object):
self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs,
tpu_address=flags_obj.tpu or "")
if self.use_tpu:
params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
......@@ -207,8 +210,16 @@ class TransformerTask(object):
flags_obj = self.flags_obj
# Sets config options.
keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
# Execute flag override logic for better model performance
# Use the set_gpu_thread_mode_and_count function from
# vision.image_classification, which is universal and not specific to
# vision or image_classification
if flags_obj.tf_gpu_thread_mode:
image_common.set_gpu_thread_mode_and_count(flags_obj)
_ensure_dir(flags_obj.model_dir)
with distribution_utils.get_strategy_scope(self.distribution_strategy):
model = transformer.create_model(params, is_train=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment