Commit 4577d2c9 authored by Rajagopal Ananthanarayanan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 292606219
parent 70e14f03
@@ -35,6 +35,7 @@ from official.vision.detection.dataloader import input_reader
 from official.vision.detection.dataloader import mode_keys as ModeKeys
 from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
 from official.vision.detection.modeling import factory as model_factory
+from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 
 hyperparams_flags.initialize_common_flags()
@@ -68,7 +69,8 @@ FLAGS = flags.FLAGS
 def run_executor(params,
                  train_input_fn=None,
                  eval_input_fn=None,
-                 callbacks=None):
+                 callbacks=None,
+                 strategy=None):
   """Runs Retinanet model on distribution strategy defined by the user."""
   if params.architecture.use_bfloat16:
@@ -78,35 +80,44 @@ def run_executor(params,
   model_builder = model_factory.model_generator(params)
 
+  if strategy is None:
+    strategy_config = params.strategy_config
+    distribution_utils.configure_cluster(strategy_config.worker_hosts,
+                                         strategy_config.task_index)
+    strategy = distribution_utils.get_distribution_strategy(
+        distribution_strategy=params.strategy_type,
+        num_gpus=strategy_config.num_gpus,
+        all_reduce_alg=strategy_config.all_reduce_alg,
+        num_packs=strategy_config.num_packs,
+        tpu_address=strategy_config.tpu)
+
+  num_workers = int(strategy.num_replicas_in_sync + 7) // 8
+  is_multi_host = (int(num_workers) >= 2)
+
   if FLAGS.mode == 'train':
     def _model_fn(params):
       return model_builder.build_model(params, mode=ModeKeys.TRAIN)
 
-    builder = executor.ExecutorBuilder(
-        strategy_type=params.strategy_type,
-        strategy_config=params.strategy_config)
-    num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
-    is_multi_host = (int(num_workers) >= 2)
     logging.info(
         'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
-        builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)
+        strategy.num_replicas_in_sync, num_workers, is_multi_host)
-    if is_multi_host:
-      train_input_fn = functools.partial(
-          train_input_fn,
-          batch_size=params.train.batch_size //
-          builder.strategy.num_replicas_in_sync)
-    dist_executor = builder.build_executor(
-        class_ctor=DetectionDistributedExecutor,
+    dist_executor = DetectionDistributedExecutor(
+        strategy=strategy,
         params=params,
+        is_multi_host=is_multi_host,
         model_fn=_model_fn,
         loss_fn=model_builder.build_loss_fn,
-        is_multi_host=is_multi_host,
         predict_post_process_fn=model_builder.post_processing,
         trainable_variables_filter=model_builder
         .make_filter_trainable_variables_fn())
+
+    if is_multi_host:
+      train_input_fn = functools.partial(
+          train_input_fn,
+          batch_size=params.train.batch_size // strategy.num_replicas_in_sync)
+
     return dist_executor.train(
         train_input_fn=train_input_fn,
         model_dir=params.model_dir,
@@ -115,30 +126,26 @@ def run_executor(params,
         init_checkpoint=model_builder.make_restore_checkpoint_fn(),
         custom_callbacks=callbacks,
         save_config=True)
   elif FLAGS.mode == 'eval' or FLAGS.mode == 'eval_once':
     def _model_fn(params):
       return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
 
-    builder = executor.ExecutorBuilder(
-        strategy_type=params.strategy_type,
-        strategy_config=params.strategy_config)
-    num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
-    is_multi_host = (int(num_workers) >= 2)
+    logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
+                 strategy.num_replicas_in_sync, num_workers, is_multi_host)
+
     if is_multi_host:
       eval_input_fn = functools.partial(
          eval_input_fn,
-          batch_size=params.eval.batch_size //
-          builder.strategy.num_replicas_in_sync)
+          batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)
 
-    logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
-                 builder.strategy.num_replicas_in_sync, num_workers,
-                 is_multi_host)
-    dist_executor = builder.build_executor(
-        class_ctor=DetectionDistributedExecutor,
+    dist_executor = DetectionDistributedExecutor(
+        strategy=strategy,
         params=params,
+        is_multi_host=is_multi_host,
         model_fn=_model_fn,
         loss_fn=model_builder.build_loss_fn,
-        is_multi_host=is_multi_host,
         predict_post_process_fn=model_builder.post_processing,
         trainable_variables_filter=model_builder
         .make_filter_trainable_variables_fn())
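After this change, run_executor no longer builds its distribution strategy through executor.ExecutorBuilder: it either receives one through the new strategy argument or, when strategy is None, falls back to distribution_utils using params.strategy_config. The sketch below (not part of the commit) shows how a caller might construct the strategy up front and pass it in; it assumes a params object carrying the same fields the diff reads, and that run_executor from the modified module is in scope. Names prefixed with run_with_ are hypothetical.

# Hedged caller-side sketch of the new `strategy` argument.
from official.utils.misc import distribution_utils

def run_with_external_strategy(params, train_input_fn, eval_input_fn):
  strategy_config = params.strategy_config

  # Configure multi-worker cluster membership before building the strategy,
  # mirroring the fallback path inside run_executor.
  distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                       strategy_config.task_index)

  # Build the tf.distribute strategy once, up front, with the same options
  # run_executor would use on its own.
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=params.strategy_type,
      num_gpus=strategy_config.num_gpus,
      all_reduce_alg=strategy_config.all_reduce_alg,
      num_packs=strategy_config.num_packs,
      tpu_address=strategy_config.tpu)

  # Pass the pre-built strategy in; run_executor then skips its own
  # `if strategy is None:` construction branch.
  return run_executor(
      params,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      callbacks=None,
      strategy=strategy)

Inside run_executor, the worker count is derived as int(strategy.num_replicas_in_sync + 7) // 8, a ceiling division that assumes eight replicas per host; when that yields two or more workers, is_multi_host is set and the global train/eval batch size is divided by strategy.num_replicas_in_sync before being handed to the input functions.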