Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4577d2c9
"vscode:/vscode.git/clone" did not exist on "5f9cf110541c5fd3f6cb70e53effeaef151a9443"
Commit
4577d2c9
authored
Jan 31, 2020
by
Rajagopal Ananthanarayanan
Committed by
A. Unique TensorFlower
Jan 31, 2020
Browse files
Internal change
PiperOrigin-RevId: 292606219
parent
70e14f03
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
28 deletions
+35
-28
official/vision/detection/main.py
official/vision/detection/main.py
+35
-28
No files found.
official/vision/detection/main.py
View file @
4577d2c9
...
...
@@ -35,6 +35,7 @@ from official.vision.detection.dataloader import input_reader
from
official.vision.detection.dataloader
import
mode_keys
as
ModeKeys
from
official.vision.detection.executor.detection_executor
import
DetectionDistributedExecutor
from
official.vision.detection.modeling
import
factory
as
model_factory
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
hyperparams_flags
.
initialize_common_flags
()
...
...
@@ -68,7 +69,8 @@ FLAGS = flags.FLAGS
def
run_executor
(
params
,
train_input_fn
=
None
,
eval_input_fn
=
None
,
callbacks
=
None
):
callbacks
=
None
,
strategy
=
None
):
"""Runs Retinanet model on distribution strategy defined by the user."""
if
params
.
architecture
.
use_bfloat16
:
...
...
@@ -78,35 +80,44 @@ def run_executor(params,
model_builder
=
model_factory
.
model_generator
(
params
)
if
strategy
is
None
:
strategy_config
=
params
.
strategy_config
distribution_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
strategy_config
.
task_index
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
strategy_type
,
num_gpus
=
strategy_config
.
num_gpus
,
all_reduce_alg
=
strategy_config
.
all_reduce_alg
,
num_packs
=
strategy_config
.
num_packs
,
tpu_address
=
strategy_config
.
tpu
)
num_workers
=
int
(
strategy
.
num_replicas_in_sync
+
7
)
//
8
is_multi_host
=
(
int
(
num_workers
)
>=
2
)
if
FLAGS
.
mode
==
'train'
:
def
_model_fn
(
params
):
return
model_builder
.
build_model
(
params
,
mode
=
ModeKeys
.
TRAIN
)
builder
=
executor
.
ExecutorBuilder
(
strategy_type
=
params
.
strategy_type
,
strategy_config
=
params
.
strategy_config
)
num_workers
=
int
(
builder
.
strategy
.
num_replicas_in_sync
+
7
)
//
8
is_multi_host
=
(
int
(
num_workers
)
>=
2
)
logging
.
info
(
'Train num_replicas_in_sync %d num_workers %d is_multi_host %s'
,
builder
.
strategy
.
num_replicas_in_sync
,
num_workers
,
is_multi_host
)
if
is_multi_host
:
train_input_fn
=
functools
.
partial
(
train_input_fn
,
batch_size
=
params
.
train
.
batch_size
//
builder
.
strategy
.
num_replicas_in_sync
)
strategy
.
num_replicas_in_sync
,
num_workers
,
is_multi_host
)
dist_executor
=
builder
.
build_e
xecutor
(
class_ctor
=
DetectionDistributedExecutor
,
dist_executor
=
DetectionDistributedE
xecutor
(
strategy
=
strategy
,
params
=
params
,
is_multi_host
=
is_multi_host
,
model_fn
=
_model_fn
,
loss_fn
=
model_builder
.
build_loss_fn
,
is_multi_host
=
is_multi_host
,
predict_post_process_fn
=
model_builder
.
post_processing
,
trainable_variables_filter
=
model_builder
.
make_filter_trainable_variables_fn
())
if
is_multi_host
:
train_input_fn
=
functools
.
partial
(
train_input_fn
,
batch_size
=
params
.
train
.
batch_size
//
strategy
.
num_replicas_in_sync
)
return
dist_executor
.
train
(
train_input_fn
=
train_input_fn
,
model_dir
=
params
.
model_dir
,
...
...
@@ -115,30 +126,26 @@ def run_executor(params,
init_checkpoint
=
model_builder
.
make_restore_checkpoint_fn
(),
custom_callbacks
=
callbacks
,
save_config
=
True
)
elif
FLAGS
.
mode
==
'eval'
or
FLAGS
.
mode
==
'eval_once'
:
def
_model_fn
(
params
):
return
model_builder
.
build_model
(
params
,
mode
=
ModeKeys
.
PREDICT_WITH_GT
)
builder
=
executor
.
ExecutorBuilder
(
strategy_type
=
params
.
strategy_type
,
strategy_config
=
params
.
strategy_config
)
num_workers
=
int
(
builder
.
strategy
.
num_replicas_in_sync
+
7
)
//
8
is_multi_host
=
(
int
(
num_workers
)
>=
2
)
logging
.
info
(
'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s'
,
strategy
.
num_replicas_in_sync
,
num_workers
,
is_multi_host
)
if
is_multi_host
:
eval_input_fn
=
functools
.
partial
(
eval_input_fn
,
batch_size
=
params
.
eval
.
batch_size
//
builder
.
strategy
.
num_replicas_in_sync
)
logging
.
info
(
'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s'
,
builder
.
strategy
.
num_replicas_in_sync
,
num_workers
,
is_multi_host
)
dist_executor
=
builder
.
build_executor
(
class_ctor
=
DetectionDistributedExecutor
,
batch_size
=
params
.
eval
.
batch_size
//
strategy
.
num_replicas_in_sync
)
dist_executor
=
DetectionDistributedExecutor
(
strategy
=
strategy
,
params
=
params
,
is_multi_host
=
is_multi_host
,
model_fn
=
_model_fn
,
loss_fn
=
model_builder
.
build_loss_fn
,
is_multi_host
=
is_multi_host
,
predict_post_process_fn
=
model_builder
.
post_processing
,
trainable_variables_filter
=
model_builder
.
make_filter_trainable_variables_fn
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment