Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0d968ea2
Commit
0d968ea2
authored
Apr 20, 2020
by
Rajagopal Ananthanarayanan
Committed by
A. Unique TensorFlower
Apr 20, 2020
Browse files
Internal change
PiperOrigin-RevId: 307436543
parent
42919740
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
10 deletions
+14
-10
official/vision/detection/main.py
official/vision/detection/main.py
+14
-10
No files found.
official/vision/detection/main.py
View file @
0d968ea2
...
...
@@ -69,10 +69,12 @@ FLAGS = flags.FLAGS
def
run_executor
(
params
,
mode
,
checkpoint_path
=
None
,
train_input_fn
=
None
,
eval_input_fn
=
None
,
callbacks
=
None
,
strategy
=
None
):
prebuilt_
strategy
=
None
):
"""Runs Retinanet model on distribution strategy defined by the user."""
if
params
.
architecture
.
use_bfloat16
:
...
...
@@ -82,7 +84,9 @@ def run_executor(params,
model_builder
=
model_factory
.
model_generator
(
params
)
if
strategy
is
None
:
if
prebuilt_strategy
is
not
None
:
strategy
=
prebuilt_strategy
else
:
strategy_config
=
params
.
strategy_config
distribution_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
strategy_config
.
task_index
)
...
...
@@ -96,7 +100,7 @@ def run_executor(params,
num_workers
=
int
(
strategy
.
num_replicas_in_sync
+
7
)
//
8
is_multi_host
=
(
int
(
num_workers
)
>=
2
)
if
FLAGS
.
mode
==
'train'
:
if
mode
==
'train'
:
def
_model_fn
(
params
):
return
model_builder
.
build_model
(
params
,
mode
=
ModeKeys
.
TRAIN
)
...
...
@@ -128,8 +132,7 @@ def run_executor(params,
init_checkpoint
=
model_builder
.
make_restore_checkpoint_fn
(),
custom_callbacks
=
callbacks
,
save_config
=
True
)
elif
FLAGS
.
mode
==
'eval'
or
FLAGS
.
mode
==
'eval_once'
:
elif
mode
==
'eval'
or
mode
==
'eval_once'
:
def
_model_fn
(
params
):
return
model_builder
.
build_model
(
params
,
mode
=
ModeKeys
.
PREDICT_WITH_GT
)
...
...
@@ -152,7 +155,7 @@ def run_executor(params,
trainable_variables_filter
=
model_builder
.
make_filter_trainable_variables_fn
())
if
FLAGS
.
mode
==
'eval'
:
if
mode
==
'eval'
:
results
=
dist_executor
.
evaluate_from_model_dir
(
model_dir
=
params
.
model_dir
,
eval_input_fn
=
eval_input_fn
,
...
...
@@ -162,9 +165,8 @@ def run_executor(params,
total_steps
=
params
.
train
.
total_steps
)
else
:
# Run evaluation once for a single checkpoint.
if
not
FLAGS
.
checkpoint_path
:
raise
ValueError
(
'FLAGS.checkpoint_path cannot be empty.'
)
checkpoint_path
=
FLAGS
.
checkpoint_path
if
not
checkpoint_path
:
raise
ValueError
(
'checkpoint_path cannot be empty.'
)
if
tf
.
io
.
gfile
.
isdir
(
checkpoint_path
):
checkpoint_path
=
tf
.
train
.
latest_checkpoint
(
checkpoint_path
)
summary_writer
=
executor
.
SummaryWriter
(
params
.
model_dir
,
'eval'
)
...
...
@@ -177,7 +179,7 @@ def run_executor(params,
logging
.
info
(
'Final eval metric %s: %f'
,
k
,
v
)
return
results
else
:
raise
ValueError
(
'Mode not found: %s.'
%
FLAGS
.
mode
)
raise
ValueError
(
'Mode not found: %s.'
%
mode
)
def
run
(
callbacks
=
None
):
...
...
@@ -239,6 +241,8 @@ def run(callbacks=None):
return
run_executor
(
params
,
FLAGS
.
mode
,
checkpoint_path
=
FLAGS
.
checkpoint_path
,
train_input_fn
=
train_input_fn
,
eval_input_fn
=
eval_input_fn
,
callbacks
=
callbacks
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment