# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Creates and runs `Estimator` for object detection model on TPUs.

This uses the TPUEstimator API to define and run a model in TRAIN/EVAL modes.
"""
# pylint: enable=line-too-long

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf


from object_detection import model_hparams
from object_detection import model_lib

# Command-line flags for this binary. Flag names, types and defaults are part
# of the script's interface and are read through FLAGS inside main().
tf.flags.DEFINE_bool('use_tpu', True, 'Use TPUs rather than plain CPUs')

# Cloud TPU Cluster Resolvers
flags.DEFINE_string(
    'gcp_project',
    default=None,
    help='Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')
flags.DEFINE_string(
    'tpu_zone',
    default=None,
    help='GCE zone where the Cloud TPU is located in. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')
flags.DEFINE_string(
    'tpu_name',
    default=None,
    help='Name of the Cloud TPU for Cluster Resolvers.')

flags.DEFINE_integer('num_shards', 8, 'Number of shards (TPU cores).')
flags.DEFINE_integer('iterations_per_loop', 100,
                     'Number of iterations per TPU training loop.')
# For mode=train_and_eval, evaluation occurs after training is finished.
# Note: independently of steps_per_checkpoint, estimator will save the most
# recent checkpoint every 10 minutes by default for train_and_eval
# NOTE(review): main() in this file only dispatches on 'train' and 'eval', and
# no steps_per_checkpoint flag is visible here -- confirm whether the comment
# above is still relevant.
flags.DEFINE_string('mode', 'train',
                    'Mode to run: train, eval')
flags.DEFINE_integer('train_batch_size', None, 'Batch size for training. If '
                     'this is not provided, batch size is read from training '
                     'config.')

flags.DEFINE_string(
    'hparams_overrides', None, 'Comma-separated list of '
    'hyperparameters to override defaults.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
flags.DEFINE_boolean('eval_training_data', False,
                     'If training data should be evaluated for this job.')
flags.DEFINE_integer('sample_1_of_n_eval_examples', 1, 'Will sample one of '
                     'every n eval input examples, where n is provided.')
flags.DEFINE_integer('sample_1_of_n_eval_on_train_examples', 5, 'Will sample '
                     'one of every n train input examples for evaluation, '
                     'where n is provided. This is only used if '
                     '`eval_training_data` is True.')
flags.DEFINE_string(
    'model_dir', None, 'Path to output model directory '
    'where event and checkpoint files will be written.')
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')

# Shared flag-values object; attributes are populated once the command line
# has been parsed.
FLAGS = tf.flags.FLAGS


def main(unused_argv):
  """Builds the TPU estimator and runs training or continuous evaluation.

  Reads its whole configuration from the module-level FLAGS: resolves the
  Cloud TPU master address, builds a TPU `RunConfig`, asks `model_lib` for the
  estimator and input functions, and then either trains (`--mode=train`) or
  continuously evaluates checkpoints in `--model_dir` (`--mode=eval`).

  Args:
    unused_argv: Remaining command-line arguments handed over by the app
      runner; ignored.

  Raises:
    ValueError: If `--mode` is neither 'train' nor 'eval'.
  """
  # NOTE(review): absl documents that required-flag validators should be
  # registered before the command line is parsed; by the time main() runs the
  # app runner has already parsed argv -- confirm these calls still enforce
  # the requirement.
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')

  # Reject unsupported modes up front instead of building the estimator and
  # then silently doing nothing.
  if FLAGS.mode not in ('train', 'eval'):
    raise ValueError(
        'Unsupported mode %r; expected one of: train, eval.' % FLAGS.mode)

  tpu_cluster_resolver = (
      tf.contrib.cluster_resolver.TPUClusterResolver(
          tpu=[FLAGS.tpu_name],
          zone=FLAGS.tpu_zone,
          project=FLAGS.gcp_project))
  tpu_grpc_url = tpu_cluster_resolver.get_master()

  # The same TPU master address is used for both training and evaluation.
  config = tf.contrib.tpu.RunConfig(
      master=tpu_grpc_url,
      evaluation_master=tpu_grpc_url,
      model_dir=FLAGS.model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_shards))

  # Only forward a batch size when one was given on the command line, so that
  # create_estimator_and_inputs is otherwise called without `batch_size`.
  kwargs = {}
  if FLAGS.train_batch_size:
    kwargs['batch_size'] = FLAGS.train_batch_size

  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
      sample_1_of_n_eval_on_train_examples=(
          FLAGS.sample_1_of_n_eval_on_train_examples),
      use_tpu_estimator=True,
      use_tpu=FLAGS.use_tpu,
      num_shards=FLAGS.num_shards,
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  train_steps = train_and_eval_dict['train_steps']

  if FLAGS.mode == 'train':
    estimator.train(input_fn=train_input_fn, max_steps=train_steps)

  # Continuously evaluating.
  if FLAGS.mode == 'eval':
    if FLAGS.eval_training_data:
      name = 'training_data'
      input_fn = eval_on_train_input_fn
    else:
      name = 'validation_data'
      # Currently only a single eval input is allowed.
      input_fn = eval_input_fns[0]
    model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, train_steps,
                              name)


if __name__ == '__main__':
  # absl evaluates flag validators while the command line is being parsed, so
  # the required flags are registered here, before tf.app.run() parses argv.
  # Doing it only when executed as a script keeps the flags optional for
  # modules that merely import this file.
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  tf.app.run()