# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A function to build a DetectionModel from configuration."""
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import box_predictor
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
from object_detection.protos import model_pb2

# A map of names to SSD feature extractors.
SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
    'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
    'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
    'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
    'ssd_mobilenet_v2': SSDMobileNetV2FeatureExtractor,
    'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
    'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
    'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
    'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
}
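
# The `type` string in an SSD feature_extractor config selects a key from the
# map above. A minimal illustrative text-format fragment (other required
# fields omitted):
#
#   feature_extractor {
#     type: 'ssd_mobilenet_v1'
#     depth_multiplier: 1.0
#     min_depth: 16
#   }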

# A map of names to Faster R-CNN feature extractors.
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_nas':
    frcnn_nas.FasterRCNNNASFeatureExtractor,
    'faster_rcnn_pnas':
    frcnn_pnas.FasterRCNNPNASFeatureExtractor,
    'faster_rcnn_inception_resnet_v2':
    frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
    'faster_rcnn_inception_v2':
    frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
    'faster_rcnn_resnet50':
    frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
    'faster_rcnn_resnet101':
    frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
    'faster_rcnn_resnet152':
    frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
}


def build(model_config, is_training, add_summaries=True):
  """Builds a DetectionModel based on the model config.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tensorflow summaries in the model graph.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: On invalid meta architecture or model.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')
  meta_architecture = model_config.WhichOneof('model')
  if meta_architecture == 'ssd':
    return _build_ssd_model(model_config.ssd, is_training, add_summaries)
  if meta_architecture == 'faster_rcnn':
    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
                                    add_summaries)
  raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
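
# Example usage (a minimal sketch): `build` expects an already-parsed
# model_pb2.DetectionModel message, typically produced by merging a text-format
# config. The fragment below parses, but a real config must also populate the
# nested ssd fields (feature_extractor, anchor_generator, box_predictor, loss,
# post_processing, etc.) or the corresponding sub-builders will raise.
#
#   from google.protobuf import text_format
#   from object_detection.protos import model_pb2
#
#   model_config = model_pb2.DetectionModel()
#   text_format.Merge('ssd { num_classes: 3 }', model_config)
#   detection_model = build(model_config, is_training=True)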


def _build_ssd_feature_extractor(feature_extractor_config, is_training,
                                 reuse_weights=None):
  """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.

  Args:
    feature_extractor_config: An SSDFeatureExtractor proto config from
      ssd.proto.
    is_training: True if this feature extractor is being built for training.
    reuse_weights: Whether the feature extractor should reuse weights.

  Returns:
    ssd_meta_arch.SSDFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  feature_type = feature_extractor_config.type
  depth_multiplier = feature_extractor_config.depth_multiplier
  min_depth = feature_extractor_config.min_depth
  pad_to_multiple = feature_extractor_config.pad_to_multiple
  use_explicit_padding = feature_extractor_config.use_explicit_padding
  use_depthwise = feature_extractor_config.use_depthwise
  conv_hyperparams = hyperparams_builder.build(
      feature_extractor_config.conv_hyperparams, is_training)
  override_base_feature_extractor_hyperparams = (
      feature_extractor_config.override_base_feature_extractor_hyperparams)

  if feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))

  feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
  return feature_extractor_class(
      is_training, depth_multiplier, min_depth, pad_to_multiple,
      conv_hyperparams, reuse_weights, use_explicit_padding, use_depthwise,
      override_base_feature_extractor_hyperparams)


def _build_ssd_model(ssd_config, is_training, add_summaries):
  """Builds an SSD detection model based on the model config.

  Args:
    ssd_config: A ssd.proto object containing the config for the desired
      SSDMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    SSDMetaArch based on the config.

  Raises:
    ValueError: If ssd_config.feature_extractor.type is not recognized (i.e.
      not registered in SSD_FEATURE_EXTRACTOR_CLASS_MAP).
  """
  num_classes = ssd_config.num_classes

  # Feature extractor
  feature_extractor = _build_ssd_feature_extractor(
      feature_extractor_config=ssd_config.feature_extractor,
      is_training=is_training)

  box_coder = box_coder_builder.build(ssd_config.box_coder)
  matcher = matcher_builder.build(ssd_config.matcher)
  region_similarity_calculator = sim_calc.build(
      ssd_config.similarity_calculator)
  encode_background_as_zeros = ssd_config.encode_background_as_zeros
  negative_class_weight = ssd_config.negative_class_weight
  ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build,
                                                  ssd_config.box_predictor,
                                                  is_training, num_classes)
  anchor_generator = anchor_generator_builder.build(
      ssd_config.anchor_generator)
  image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
  non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
      ssd_config.post_processing)
  (classification_loss, localization_loss, classification_weight,
   localization_weight,
   hard_example_miner) = losses_builder.build(ssd_config.loss)
  normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
  normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize

  return ssd_meta_arch.SSDMetaArch(
      is_training,
      anchor_generator,
      ssd_box_predictor,
      box_coder,
      feature_extractor,
      matcher,
      region_similarity_calculator,
      encode_background_as_zeros,
      negative_class_weight,
      image_resizer_fn,
      non_max_suppression_fn,
      score_conversion_fn,
      classification_loss,
      localization_loss,
      classification_weight,
      localization_weight,
      normalize_loss_by_num_matches,
      hard_example_miner,
      add_summaries=add_summaries,
      normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      inplace_batchnorm_update=ssd_config.inplace_batchnorm_update)


def _build_faster_rcnn_feature_extractor(
    feature_extractor_config, is_training, reuse_weights=None,
    inplace_batchnorm_update=False):
  """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    reuse_weights: Whether the feature extractor should reuse weights.
    inplace_batchnorm_update: Whether to update batch_norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, the user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for the train/loss op in order to update the
      batch norm moving average parameters (see the note after this function).

  Returns:
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')
  feature_type = feature_extractor_config.type
  first_stage_features_stride = (
      feature_extractor_config.first_stage_features_stride)
  batch_norm_trainable = feature_extractor_config.batch_norm_trainable

  if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
      is_training, first_stage_features_stride,
      batch_norm_trainable, reuse_weights)
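
# A sketch of the UPDATE_OPS pattern referred to in the docstring above: when
# inplace_batchnorm_update is False, the training loop is expected to gate the
# train op on tf.GraphKeys.UPDATE_OPS so the batch norm moving averages get
# updated (`optimizer` and `total_loss` below are placeholders):
#
#   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#   with tf.control_dependencies(update_ops):
#     train_op = optimizer.minimize(total_loss)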


def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds an R-FCN model if the second_stage_box_predictor in the config is of
  type `rfcn_box_predictor`; otherwise it builds a Faster R-CNN model (see the
  config sketch at the end of this module).

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    FasterRCNNMetaArch or RFCNMetaArch based on the config.

  Raises:
    ValueError: If frcnn_config.feature_extractor.type is not recognized (i.e.
      not registered in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP).
  """
  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

  feature_extractor = _build_faster_rcnn_feature_extractor(
      frcnn_config.feature_extractor, is_training,
      inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

  number_of_stages = frcnn_config.number_of_stages
  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
  first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
      frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      frcnn_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
  first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
  first_stage_positive_balance_fraction = (
      frcnn_config.first_stage_positive_balance_fraction)
  first_stage_nms_score_threshold = frcnn_config.first_stage_nms_score_threshold
  first_stage_nms_iou_threshold = frcnn_config.first_stage_nms_iou_threshold
  first_stage_max_proposals = frcnn_config.first_stage_max_proposals
  first_stage_loc_loss_weight = (
      frcnn_config.first_stage_localization_loss_weight)
  first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

  initial_crop_size = frcnn_config.initial_crop_size
  maxpool_kernel_size = frcnn_config.maxpool_kernel_size
  maxpool_stride = frcnn_config.maxpool_stride

  second_stage_box_predictor = box_predictor_builder.build(
      hyperparams_builder.build,
      frcnn_config.second_stage_box_predictor,
      is_training=is_training,
      num_classes=num_classes)
  second_stage_batch_size = frcnn_config.second_stage_batch_size
  second_stage_balance_fraction = frcnn_config.second_stage_balance_fraction
  (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
  ) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      frcnn_config.second_stage_localization_loss_weight)
  second_stage_classification_loss = (
      losses_builder.build_faster_rcnn_classification_loss(
          frcnn_config.second_stage_classification_loss))
  second_stage_classification_loss_weight = (
      frcnn_config.second_stage_classification_loss_weight)
  second_stage_mask_prediction_loss_weight = (
      frcnn_config.second_stage_mask_prediction_loss_weight)

  hard_example_miner = None
  if frcnn_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        frcnn_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)

  common_kwargs = {
      'is_training': is_training,
      'num_classes': num_classes,
      'image_resizer_fn': image_resizer_fn,
      'feature_extractor': feature_extractor,
      'number_of_stages': number_of_stages,
      'first_stage_anchor_generator': first_stage_anchor_generator,
      'first_stage_atrous_rate': first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope_fn':
      first_stage_box_predictor_arg_scope_fn,
      'first_stage_box_predictor_kernel_size':
      first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
      'first_stage_minibatch_size': first_stage_minibatch_size,
      'first_stage_positive_balance_fraction':
      first_stage_positive_balance_fraction,
      'first_stage_nms_score_threshold': first_stage_nms_score_threshold,
      'first_stage_nms_iou_threshold': first_stage_nms_iou_threshold,
      'first_stage_max_proposals': first_stage_max_proposals,
      'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
      'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
      'second_stage_batch_size': second_stage_batch_size,
      'second_stage_balance_fraction': second_stage_balance_fraction,
      'second_stage_non_max_suppression_fn':
      second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
      second_stage_localization_loss_weight,
      'second_stage_classification_loss':
      second_stage_classification_loss,
      'second_stage_classification_loss_weight':
      second_stage_classification_loss_weight,
      'hard_example_miner': hard_example_miner,
      'add_summaries': add_summaries}

  if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
    return rfcn_meta_arch.RFCNMetaArch(
        second_stage_rfcn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
  else:
    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=initial_crop_size,
        maxpool_kernel_size=maxpool_kernel_size,
        maxpool_stride=maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        second_stage_mask_prediction_loss_weight=(
            second_stage_mask_prediction_loss_weight),
        **common_kwargs)
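
# A sketch of the config distinction described in _build_faster_rcnn_model's
# docstring (illustrative text-format fragments, field values omitted): an
# `rfcn_box_predictor` in the second stage yields an RFCNMetaArch, while any
# other predictor (typically `mask_rcnn_box_predictor`) yields a
# FasterRCNNMetaArch.
#
#   second_stage_box_predictor { rfcn_box_predictor { ... } }
#       -> rfcn_meta_arch.RFCNMetaArch
#   second_stage_box_predictor { mask_rcnn_box_predictor { ... } }
#       -> faster_rcnn_meta_arch.FasterRCNNMetaArch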