Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9ca59f8a
Commit
9ca59f8a
authored
Oct 21, 2019
by
Yeqing Li
Committed by
A. Unique TensorFlower
Oct 21, 2019
Browse files
Internal change
PiperOrigin-RevId: 275867562
parent
befbe0f9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
23 deletions
+49
-23
official/modeling/training/distributed_executor.py
official/modeling/training/distributed_executor.py
+2
-3
official/vision/detection/dataloader/input_reader.py
official/vision/detection/dataloader/input_reader.py
+3
-7
official/vision/detection/main.py
official/vision/detection/main.py
+42
-11
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+1
-2
official/vision/detection/utils/box_utils.py
official/vision/detection/utils/box_utils.py
+1
-0
No files found.
official/modeling/training/distributed_executor.py
View file @
9ca59f8a
...
...
@@ -253,8 +253,7 @@ class DistributedExecutor(object):
logging
.
warning
(
'model_dir is empty, so skip the save config.'
)
def
_get_input_iterator
(
self
,
input_fn
:
Callable
[[
Optional
[
params_dict
.
ParamsDict
]],
tf
.
data
.
Dataset
],
self
,
input_fn
:
Callable
[...,
tf
.
data
.
Dataset
],
strategy
:
tf
.
distribute
.
Strategy
)
->
Optional
[
Iterator
[
Any
]]:
"""Returns distributed dataset iterator.
...
...
@@ -275,7 +274,7 @@ class DistributedExecutor(object):
return
iter
(
strategy
.
experimental_distribute_datasets_from_function
(
input_fn
))
else
:
input_data
=
input_fn
(
self
.
_params
)
input_data
=
input_fn
()
return
iter
(
strategy
.
experimental_distribute_dataset
(
input_data
))
def
_create_replicated_step
(
self
,
...
...
official/vision/detection/dataloader/input_reader.py
View file @
9ca59f8a
...
...
@@ -58,16 +58,12 @@ class InputFn(object):
self
.
_parser_fn
=
factory
.
parser_generator
(
params
,
mode
)
self
.
_dataset_fn
=
tf
.
data
.
TFRecordDataset
def
__call__
(
self
,
params
:
params_dict
.
ParamsDict
=
None
,
batch_size
=
None
,
ctx
=
None
):
def
__call__
(
self
,
ctx
=
None
,
batch_size
:
int
=
None
):
"""Provides tf.data.Dataset object.
Args:
params: placeholder for model parameters.
batch_size: expected batch size input data.
ctx: context object.
batch_size: expected batch size input data.
Returns:
tf.data.Dataset object.
...
...
@@ -96,6 +92,6 @@ class InputFn(object):
# Parses the fetched records to input tensors for model function.
dataset
=
dataset
.
map
(
self
.
_parser_fn
,
num_parallel_calls
=
64
)
dataset
=
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
dataset
.
batch
(
batch_size
,
drop_remainder
=
True
)
dataset
=
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
official/vision/detection/main.py
View file @
9ca59f8a
...
...
@@ -51,6 +51,9 @@ flags.DEFINE_string('training_file_pattern', None,
flags
.
DEFINE_string
(
'eval_file_pattern'
,
None
,
'Location of ther eval data'
)
flags
.
DEFINE_string
(
'checkpoint_path'
,
None
,
'The checkpoint path to eval. Only used in eval_once mode.'
)
FLAGS
=
flags
.
FLAGS
...
...
@@ -71,8 +74,11 @@ def run_executor(params,
builder
=
executor
.
ExecutorBuilder
(
strategy_type
=
params
.
strategy_type
,
strategy_config
=
params
.
strategy_config
)
num_workers
=
(
builder
.
strategy
.
num_replicas_in_sync
+
7
)
/
8
is_multi_host
=
(
num_workers
>
1
)
num_workers
=
int
(
builder
.
strategy
.
num_replicas_in_sync
+
7
)
//
8
is_multi_host
=
(
int
(
num_workers
)
>=
2
)
logging
.
info
(
'Train num_replicas_in_sync %d num_workers %d is_multi_host %s'
,
builder
.
strategy
.
num_replicas_in_sync
,
num_workers
,
is_multi_host
)
if
is_multi_host
:
train_input_fn
=
functools
.
partial
(
train_input_fn
,
...
...
@@ -97,7 +103,7 @@ def run_executor(params,
init_checkpoint
=
model_builder
.
make_restore_checkpoint_fn
(),
custom_callbacks
=
callbacks
,
save_config
=
True
)
elif
FLAGS
.
mode
==
'eval'
:
elif
FLAGS
.
mode
==
'eval'
or
FLAGS
.
mode
==
'eval_once'
:
def
_model_fn
(
params
):
return
model_builder
.
build_model
(
params
,
mode
=
ModeKeys
.
PREDICT_WITH_GT
)
...
...
@@ -105,22 +111,47 @@ def run_executor(params,
builder
=
executor
.
ExecutorBuilder
(
strategy_type
=
params
.
strategy_type
,
strategy_config
=
params
.
strategy_config
)
num_workers
=
int
(
builder
.
strategy
.
num_replicas_in_sync
+
7
)
//
8
is_multi_host
=
(
int
(
num_workers
)
>=
2
)
if
is_multi_host
:
eval_input_fn
=
functools
.
partial
(
eval_input_fn
,
batch_size
=
params
.
eval
.
batch_size
//
builder
.
strategy
.
num_replicas_in_sync
)
logging
.
info
(
'Eval num_replicas_in_sync %d num_workers %d is_multi_host %s'
,
builder
.
strategy
.
num_replicas_in_sync
,
num_workers
,
is_multi_host
)
dist_executor
=
builder
.
build_executor
(
class_ctor
=
DetectionDistributedExecutor
,
params
=
params
,
is_multi_host
=
is_multi_host
,
model_fn
=
_model_fn
,
loss_fn
=
model_builder
.
build_loss_fn
,
predict_post_process_fn
=
model_builder
.
post_processing
,
trainable_variables_filter
=
model_builder
.
make_filter_trainable_variables_fn
())
results
=
dist_executor
.
evaluate_from_model_dir
(
model_dir
=
params
.
model_dir
,
eval_input_fn
=
eval_input_fn
,
eval_metric_fn
=
model_builder
.
eval_metrics
,
eval_timeout
=
params
.
eval
.
eval_timeout
,
min_eval_interval
=
params
.
eval
.
min_eval_interval
,
total_steps
=
params
.
train
.
total_steps
)
if
FLAGS
.
mode
==
'eval'
:
results
=
dist_executor
.
evaluate_from_model_dir
(
model_dir
=
params
.
model_dir
,
eval_input_fn
=
eval_input_fn
,
eval_metric_fn
=
model_builder
.
eval_metrics
,
eval_timeout
=
params
.
eval
.
eval_timeout
,
min_eval_interval
=
params
.
eval
.
min_eval_interval
,
total_steps
=
params
.
train
.
total_steps
)
else
:
# Run evaluation once for a single checkpoint.
if
not
FLAGS
.
checkpoint_path
:
raise
ValueError
(
'FLAGS.checkpoint_path cannot be empty.'
)
checkpoint_path
=
FLAGS
.
checkpoint_path
if
tf
.
io
.
gfile
.
isdir
(
checkpoint_path
):
checkpoint_path
=
tf
.
train
.
latest_checkpoint
(
checkpoint_path
)
summary_writer
=
executor
.
SummaryWriter
(
params
.
model_dir
,
'eval'
)
results
,
_
=
dist_executor
.
evaluate_checkpoint
(
checkpoint_path
=
checkpoint_path
,
eval_input_fn
=
eval_input_fn
,
eval_metric_fn
=
model_builder
.
eval_metrics
,
summary_writer
=
summary_writer
)
for
k
,
v
in
results
.
items
():
logging
.
info
(
'Final eval metric %s: %f'
,
k
,
v
)
return
results
...
...
@@ -182,7 +213,7 @@ def run(callbacks=None):
def
main
(
argv
):
del
argv
# Unused.
return
run
()
run
()
if
__name__
==
'__main__'
:
...
...
official/vision/detection/modeling/retinanet_model.py
View file @
9ca59f8a
...
...
@@ -60,8 +60,7 @@ class COCOMetrics(object):
return
self
.
_evaluator
.
evaluate
()
def
reset_states
(
self
):
logging
.
info
(
'State is reset on calling metric.result().'
)
pass
return
self
.
_evaluator
.
reset
()
class
RetinanetModel
(
base_model
.
Model
):
...
...
official/vision/detection/utils/box_utils.py
View file @
9ca59f8a
...
...
@@ -16,6 +16,7 @@
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
numpy
as
np
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment