Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c21d3c25
Commit
c21d3c25
authored
Mar 05, 2018
by
Zhichao Lu
Committed by
lzc5123016
Mar 07, 2018
Browse files
Allow model.py to be extended with custom model building functions.
PiperOrigin-RevId: 187941168
parent
307f1f77
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
10 deletions
+37
-10
research/object_detection/inputs.py
research/object_detection/inputs.py
+4
-4
research/object_detection/model.py
research/object_detection/model.py
+33
-6
No files found.
research/object_detection/inputs.py
View file @
c21d3c25
...
...
@@ -200,8 +200,8 @@ def create_train_input_fn(train_config, train_input_config,
keypoints for each box.
Raises:
TypeError: if the `train_config`
or
`train_input_config`
are not of the
correct type.
TypeError: if the `train_config`
,
`train_input_config`
or `model_config`
are not of the
correct type.
"""
if
not
isinstance
(
train_config
,
train_pb2
.
TrainConfig
):
raise
TypeError
(
'For training mode, the `train_config` must be a '
...
...
@@ -316,8 +316,8 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
which represent instance masks for objects.
Raises:
TypeError: if the `eval_config`
or
`eval_input_config`
are not of the
correct type.
TypeError: if the `eval_config`
,
`eval_input_config`
or `model_config`
are not of the
correct type.
"""
del
params
if
not
isinstance
(
eval_config
,
eval_pb2
.
EvalConfig
):
...
...
research/object_detection/model.py
View file @
c21d3c25
...
...
@@ -32,6 +32,7 @@ import tensorflow as tf
from
google.protobuf
import
text_format
from
tensorflow.contrib.learn.python.learn
import
learn_runner
from
tensorflow.contrib.tpu.python.tpu
import
tpu_optimizer
from
tensorflow.python.lib.io
import
file_io
from
object_detection
import
eval_util
from
object_detection
import
inputs
from
object_detection
import
model_hparams
...
...
@@ -54,6 +55,20 @@ tf.flags.DEFINE_integer('num_eval_steps', 10000, 'Number of train steps.')
FLAGS
=
tf
.
flags
.
FLAGS
# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP
=
{
'get_configs_from_pipeline_file'
:
config_util
.
get_configs_from_pipeline_file
,
'create_pipeline_proto_from_configs'
:
config_util
.
create_pipeline_proto_from_configs
,
'merge_external_params_with_configs'
:
config_util
.
merge_external_params_with_configs
,
'create_train_input_fn'
:
inputs
.
create_train_input_fn
,
'create_eval_input_fn'
:
inputs
.
create_eval_input_fn
,
'create_predict_input_fn'
:
inputs
.
create_predict_input_fn
,
}
def
_get_groundtruth_data
(
detection_model
,
class_agnostic
):
"""Extracts groundtruth data from detection_model.
...
...
@@ -413,8 +428,18 @@ def populate_experiment(run_config,
An `Experiment` that defines all aspects of training, evaluation, and
export.
"""
configs
=
config_util
.
get_configs_from_pipeline_file
(
pipeline_config_path
)
configs
=
config_util
.
merge_external_params_with_configs
(
get_configs_from_pipeline_file
=
MODEL_BUILD_UTIL_MAP
[
'get_configs_from_pipeline_file'
]
create_pipeline_proto_from_configs
=
MODEL_BUILD_UTIL_MAP
[
'create_pipeline_proto_from_configs'
]
merge_external_params_with_configs
=
MODEL_BUILD_UTIL_MAP
[
'merge_external_params_with_configs'
]
create_train_input_fn
=
MODEL_BUILD_UTIL_MAP
[
'create_train_input_fn'
]
create_eval_input_fn
=
MODEL_BUILD_UTIL_MAP
[
'create_eval_input_fn'
]
create_predict_input_fn
=
MODEL_BUILD_UTIL_MAP
[
'create_predict_input_fn'
]
configs
=
get_configs_from_pipeline_file
(
pipeline_config_path
)
configs
=
merge_external_params_with_configs
(
configs
,
hparams
,
train_steps
=
train_steps
,
...
...
@@ -436,18 +461,18 @@ def populate_experiment(run_config,
model_builder
.
build
,
model_config
=
model_config
)
# Create the input functions for TRAIN/EVAL.
train_input_fn
=
inputs
.
create_train_input_fn
(
train_input_fn
=
create_train_input_fn
(
train_config
=
train_config
,
train_input_config
=
train_input_config
,
model_config
=
model_config
)
eval_input_fn
=
inputs
.
create_eval_input_fn
(
eval_input_fn
=
create_eval_input_fn
(
eval_config
=
eval_config
,
eval_input_config
=
eval_input_config
,
model_config
=
model_config
)
export_strategies
=
[
tf
.
contrib
.
learn
.
utils
.
saved_model_export_utils
.
make_export_strategy
(
serving_input_fn
=
inputs
.
create_predict_input_fn
(
serving_input_fn
=
create_predict_input_fn
(
model_config
=
model_config
))
]
...
...
@@ -457,8 +482,10 @@ def populate_experiment(run_config,
if
run_config
.
is_chief
:
# Store the final pipeline config for traceability.
pipeline_config_final
=
config_util
.
create_pipeline_proto_from_configs
(
pipeline_config_final
=
create_pipeline_proto_from_configs
(
configs
)
if
not
file_io
.
file_exists
(
estimator
.
model_dir
):
file_io
.
recursive_create_dir
(
estimator
.
model_dir
)
pipeline_config_final_path
=
os
.
path
.
join
(
estimator
.
model_dir
,
'pipeline.config'
)
config_text
=
text_format
.
MessageToString
(
pipeline_config_final
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment