"vscode:/vscode.git/clone" did not exist on "5f9cf110541c5fd3f6cb70e53effeaef151a9443"
Commit 4577d2c9 authored by Rajagopal Ananthanarayanan, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 292606219
parent 70e14f03
......@@ -35,6 +35,7 @@ from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
hyperparams_flags.initialize_common_flags()
......@@ -68,7 +69,8 @@ FLAGS = flags.FLAGS
def run_executor(params,
train_input_fn=None,
eval_input_fn=None,
callbacks=None):
callbacks=None,
strategy=None):
"""Runs Retinanet model on distribution strategy defined by the user."""
if params.architecture.use_bfloat16:
......@@ -78,35 +80,44 @@ def run_executor(params,
model_builder = model_factory.model_generator(params)
if strategy is None:
strategy_config = params.strategy_config
distribution_utils.configure_cluster(strategy_config.worker_hosts,
strategy_config.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=params.strategy_type,
num_gpus=strategy_config.num_gpus,
all_reduce_alg=strategy_config.all_reduce_alg,
num_packs=strategy_config.num_packs,
tpu_address=strategy_config.tpu)
num_workers = int(strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
if FLAGS.mode == 'train':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.TRAIN)
builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type,
strategy_config=params.strategy_config)
num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
logging.info(
'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
builder.strategy.num_replicas_in_sync, num_workers, is_multi_host)
if is_multi_host:
train_input_fn = functools.partial(
train_input_fn,
batch_size=params.train.batch_size //
builder.strategy.num_replicas_in_sync)
strategy.num_replicas_in_sync, num_workers, is_multi_host)
dist_executor = builder.build_executor(
class_ctor=DetectionDistributedExecutor,
dist_executor = DetectionDistributedExecutor(
strategy=strategy,
params=params,
is_multi_host=is_multi_host,
model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn,
is_multi_host=is_multi_host,
predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
if is_multi_host:
train_input_fn = functools.partial(
train_input_fn,
batch_size=params.train.batch_size // strategy.num_replicas_in_sync)
return dist_executor.train(
train_input_fn=train_input_fn,
model_dir=params.model_dir,
......@@ -115,30 +126,26 @@ def run_executor(params,
init_checkpoint=model_builder.make_restore_checkpoint_fn(),
custom_callbacks=callbacks,
save_config=True)
elif FLAGS.mode == 'eval' or FLAGS.mode == 'eval_once':
def _model_fn(params):
return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
builder = executor.ExecutorBuilder(
strategy_type=params.strategy_type,
strategy_config=params.strategy_config)
num_workers = int(builder.strategy.num_replicas_in_sync + 7) // 8
is_multi_host = (int(num_workers) >= 2)
logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
strategy.num_replicas_in_sync, num_workers, is_multi_host)
if is_multi_host:
eval_input_fn = functools.partial(
eval_input_fn,
batch_size=params.eval.batch_size //
builder.strategy.num_replicas_in_sync)
logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
builder.strategy.num_replicas_in_sync, num_workers,
is_multi_host)
dist_executor = builder.build_executor(
class_ctor=DetectionDistributedExecutor,
batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)
dist_executor = DetectionDistributedExecutor(
strategy=strategy,
params=params,
is_multi_host=is_multi_host,
model_fn=_model_fn,
loss_fn=model_builder.build_loss_fn,
is_multi_host=is_multi_host,
predict_post_process_fn=model_builder.post_processing,
trainable_variables_filter=model_builder
.make_filter_trainable_variables_fn())
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment