"...models/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "08743385d93a0aa4145da6f6db49bfa00df148a3"
Commit bd73fdfe authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Enabled performance-related parameters for Transformer: all_reduce_alg,
enable_eager, tf_gpu_thread_mode and datasets_num_private_threads.

PiperOrigin-RevId: 282804724
parent 808b2f5e
...@@ -71,9 +71,6 @@ def define_transformer_flags(): ...@@ -71,9 +71,6 @@ def define_transformer_flags():
dtype=True, dtype=True,
loss_scale=True, loss_scale=True,
all_reduce_alg=True, all_reduce_alg=True,
num_packs=True,
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
enable_xla=True, enable_xla=True,
force_v2_in_keras_compile=True, force_v2_in_keras_compile=True,
fp16_implementation=True fp16_implementation=True
...@@ -89,14 +86,10 @@ def define_transformer_flags(): ...@@ -89,14 +86,10 @@ def define_transformer_flags():
'convolutions and batch normalizations, and this flag allows to ' 'convolutions and batch normalizations, and this flag allows to '
'disable it.' 'disable it.'
) )
flags_core.define_benchmark() flags_core.define_benchmark()
flags_core.define_device(tpu=True) flags_core.define_device(tpu=True)
flags.DEFINE_boolean(
name='enable_eager', default=False,
help='Enable eager mode? (Note: this is NOT run eagerly / op-by-op mode)')
flags.DEFINE_integer( flags.DEFINE_integer(
name='train_steps', short_name='ts', default=300000, name='train_steps', short_name='ts', default=300000,
help=flags_core.help_wrap('The number of steps used to train.')) help=flags_core.help_wrap('The number of steps used to train.'))
......
...@@ -43,7 +43,6 @@ from official.utils.flags import core as flags_core ...@@ -43,7 +43,6 @@ from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.vision.image_classification import common as image_common
INF = int(1e9) INF = int(1e9)
...@@ -165,8 +164,6 @@ class TransformerTask(object): ...@@ -165,8 +164,6 @@ class TransformerTask(object):
self.distribution_strategy = distribution_utils.get_distribution_strategy( self.distribution_strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy, distribution_strategy=flags_obj.distribution_strategy,
num_gpus=num_gpus, num_gpus=num_gpus,
all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs,
tpu_address=flags_obj.tpu or "") tpu_address=flags_obj.tpu or "")
if self.use_tpu: if self.use_tpu:
params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
...@@ -210,16 +207,8 @@ class TransformerTask(object): ...@@ -210,16 +207,8 @@ class TransformerTask(object):
flags_obj = self.flags_obj flags_obj = self.flags_obj
# Sets config options. # Sets config options.
keras_utils.set_session_config( keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
# Execute flag override logic for better model performance
# Use the set_gpu_thread_mode_and_count function from
# vision.image_classification, which is universal and not specific to
# vision or image_classification
if flags_obj.tf_gpu_thread_mode:
image_common.set_gpu_thread_mode_and_count(flags_obj)
_ensure_dir(flags_obj.model_dir) _ensure_dir(flags_obj.model_dir)
with distribution_utils.get_strategy_scope(self.distribution_strategy): with distribution_utils.get_strategy_scope(self.distribution_strategy):
model = transformer.create_model(params, is_train=True) model = transformer.create_model(params, is_train=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment