Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
fbe9b495
"vscode:/vscode.git/clone" did not exist on "fffcd2358c661eeb9124998ae36993818a6c97fa"
Commit
fbe9b495
authored
Jul 23, 2020
by
Kaushik Shivakumar
Browse files
model main tf2
parent
ca244433
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
106 additions
and
0 deletions
+106
-0
research/object_detection/model_main_tf2.py
research/object_detection/model_main_tf2.py
+106
-0
No files found.
research/object_detection/model_main_tf2.py
0 → 100644
View file @
fbe9b495
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Creates and runs TF2 object detection models.
For local training/evaluation run:
PIPELINE_CONFIG_PATH=path/to/pipeline.config
MODEL_DIR=/tmp/model_outputs
NUM_TRAIN_STEPS=10000
SAMPLE_1_OF_N_EVAL_EXAMPLES=1
python model_main_tf2.py -- \
--model_dir=$MODEL_DIR --num_train_steps=$NUM_TRAIN_STEPS \
--sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \
--pipeline_config_path=$PIPELINE_CONFIG_PATH \
--alsologtostderr
"""
from
absl
import
flags
import
tensorflow.compat.v2
as
tf
from
object_detection
import
model_lib_v2
flags
.
DEFINE_string
(
'pipeline_config_path'
,
None
,
'Path to pipeline config '
'file.'
)
flags
.
DEFINE_integer
(
'num_train_steps'
,
None
,
'Number of train steps.'
)
flags
.
DEFINE_bool
(
'eval_on_train_data'
,
False
,
'Enable evaluating on train '
'data (only supported in distributed training).'
)
flags
.
DEFINE_integer
(
'sample_1_of_n_eval_examples'
,
None
,
'Will sample one of '
'every n eval input examples, where n is provided.'
)
flags
.
DEFINE_integer
(
'sample_1_of_n_eval_on_train_examples'
,
5
,
'Will sample '
'one of every n train input examples for evaluation, '
'where n is provided. This is only used if '
'`eval_training_data` is True.'
)
flags
.
DEFINE_string
(
'model_dir'
,
None
,
'Path to output model directory '
'where event and checkpoint files will be written.'
)
flags
.
DEFINE_string
(
'checkpoint_dir'
,
None
,
'Path to directory holding a checkpoint. If '
'`checkpoint_dir` is provided, this binary operates in eval-only mode, '
'writing resulting metrics to `model_dir`.'
)
flags
.
DEFINE_integer
(
'eval_timeout'
,
3600
,
'Number of seconds to wait for an'
'evaluation checkpoint before exiting.'
)
flags
.
DEFINE_bool
(
'use_tpu'
,
False
,
'Whether the job is executing on a TPU.'
)
flags
.
DEFINE_string
(
'tpu_name'
,
default
=
None
,
help
=
'Name of the Cloud TPU for Cluster Resolvers.'
)
flags
.
DEFINE_integer
(
'num_workers'
,
1
,
'When num_workers > 1, training uses '
'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
'MirroredStrategy.'
)
FLAGS
=
flags
.
FLAGS
def
main
(
unused_argv
):
flags
.
mark_flag_as_required
(
'model_dir'
)
flags
.
mark_flag_as_required
(
'pipeline_config_path'
)
tf
.
config
.
set_soft_device_placement
(
True
)
if
FLAGS
.
checkpoint_dir
:
model_lib_v2
.
eval_continuously
(
pipeline_config_path
=
FLAGS
.
pipeline_config_path
,
model_dir
=
FLAGS
.
model_dir
,
train_steps
=
FLAGS
.
num_train_steps
,
sample_1_of_n_eval_examples
=
FLAGS
.
sample_1_of_n_eval_examples
,
sample_1_of_n_eval_on_train_examples
=
(
FLAGS
.
sample_1_of_n_eval_on_train_examples
),
checkpoint_dir
=
FLAGS
.
checkpoint_dir
,
wait_interval
=
300
,
timeout
=
FLAGS
.
eval_timeout
)
else
:
if
FLAGS
.
use_tpu
:
# TPU is automatically inferred if tpu_name is None and
# we are running under cloud ai-platform.
resolver
=
tf
.
distribute
.
cluster_resolver
.
TPUClusterResolver
(
FLAGS
.
tpu_name
)
tf
.
config
.
experimental_connect_to_cluster
(
resolver
)
tf
.
tpu
.
experimental
.
initialize_tpu_system
(
resolver
)
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
(
resolver
)
elif
FLAGS
.
num_workers
>
1
:
strategy
=
tf
.
distribute
.
experimental
.
MultiWorkerMirroredStrategy
()
else
:
strategy
=
tf
.
compat
.
v2
.
distribute
.
MirroredStrategy
()
with
strategy
.
scope
():
model_lib_v2
.
train_loop
(
pipeline_config_path
=
FLAGS
.
pipeline_config_path
,
model_dir
=
FLAGS
.
model_dir
,
train_steps
=
FLAGS
.
num_train_steps
,
use_tpu
=
FLAGS
.
use_tpu
)
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
app
.
run
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment