# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A function to build a DetectionModel from configuration."""
17

18
import functools
19

20
21
22
23
24
25
26
27
28
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
29
from object_detection.core import balanced_positive_negative_sampler as sampler
30
from object_detection.core import post_processing
31
from object_detection.core import target_assigner
32
33
34
35
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
36
from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
37
38
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
39
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
40
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
41
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
42
from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras
43
from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn
44
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
45
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
46
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
47
from object_detection.models.ssd_mobilenet_edgetpu_feature_extractor import SSDMobileNetEdgeTPUFeatureExtractor
48
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
49
from object_detection.models.ssd_mobilenet_v1_fpn_feature_extractor import SSDMobileNetV1FpnFeatureExtractor
50
from object_detection.models.ssd_mobilenet_v1_fpn_keras_feature_extractor import SSDMobileNetV1FpnKerasFeatureExtractor
51
from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor
52
from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor
53
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
54
from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor
55
from object_detection.models.ssd_mobilenet_v2_fpn_keras_feature_extractor import SSDMobileNetV2FpnKerasFeatureExtractor
56
from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor
57
58
from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3LargeFeatureExtractor
from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3SmallFeatureExtractor
59
from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
60
from object_detection.predictors import rfcn_box_predictor
61
from object_detection.predictors import rfcn_keras_box_predictor
62
from object_detection.predictors.heads import mask_head
63
from object_detection.protos import model_pb2
64
from object_detection.utils import ops
65
66
67
68

# A map of names to SSD feature extractors (non-Keras). Keys are the `type`
# strings of SSDFeatureExtractor configs; values are the extractor classes
# instantiated by _build_ssd_feature_extractor.
SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
    'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
    'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
    'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
    'ssd_mobilenet_v1_fpn': SSDMobileNetV1FpnFeatureExtractor,
    'ssd_mobilenet_v1_ppn': SSDMobileNetV1PpnFeatureExtractor,
    'ssd_mobilenet_v2': SSDMobileNetV2FeatureExtractor,
    'ssd_mobilenet_v2_fpn': SSDMobileNetV2FpnFeatureExtractor,
    'ssd_mobilenet_v3_large': SSDMobileNetV3LargeFeatureExtractor,
    'ssd_mobilenet_v3_small': SSDMobileNetV3SmallFeatureExtractor,
    'ssd_mobilenet_edgetpu': SSDMobileNetEdgeTPUFeatureExtractor,
    'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
    'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
    'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
    'ssd_resnet50_v1_ppn': ssd_resnet_v1_ppn.SSDResnet50V1PpnFeatureExtractor,
    'ssd_resnet101_v1_ppn':
        ssd_resnet_v1_ppn.SSDResnet101V1PpnFeatureExtractor,
    'ssd_resnet152_v1_ppn':
        ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor,
    'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
    'ssd_pnasnet': SSDPNASNetFeatureExtractor,
}

# A map of names to Keras SSD feature extractors. A `type` listed here makes
# _build_ssd_feature_extractor take its Keras code path (Keras layer
# hyperparams plus batch norm options instead of an arg_scope builder).
SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
    'ssd_mobilenet_v1_keras': SSDMobileNetV1KerasFeatureExtractor,
    'ssd_mobilenet_v1_fpn_keras': SSDMobileNetV1FpnKerasFeatureExtractor,
    'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor,
    'ssd_mobilenet_v2_fpn_keras': SSDMobileNetV2FpnKerasFeatureExtractor,
    'ssd_resnet50_v1_fpn_keras':
        ssd_resnet_v1_fpn_keras.SSDResNet50V1FpnKerasFeatureExtractor,
    'ssd_resnet101_v1_fpn_keras':
        ssd_resnet_v1_fpn_keras.SSDResNet101V1FpnKerasFeatureExtractor,
    'ssd_resnet152_v1_fpn_keras':
        ssd_resnet_v1_fpn_keras.SSDResNet152V1FpnKerasFeatureExtractor,
}

# A map of names to Faster R-CNN feature extractors (non-Keras). Keys are the
# `type` strings of FasterRcnnFeatureExtractor configs; values are the classes
# instantiated by _build_faster_rcnn_feature_extractor.
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_nas':
        frcnn_nas.FasterRCNNNASFeatureExtractor,
    'faster_rcnn_pnas':
        frcnn_pnas.FasterRCNNPNASFeatureExtractor,
    'faster_rcnn_inception_resnet_v2':
        frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
    'faster_rcnn_inception_v2':
        frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
    'faster_rcnn_resnet50':
        frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
    'faster_rcnn_resnet101':
        frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
    'faster_rcnn_resnet152':
        frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
}

# A map of names to Keras Faster R-CNN feature extractors. A `type` listed here
# makes _build_faster_rcnn_model use the Keras feature extractor, hyperparams
# and box predictor builders.
FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_inception_resnet_v2_keras':
        frcnn_inc_res_keras.FasterRCNNInceptionResnetV2KerasFeatureExtractor,
}

def _build_ssd_feature_extractor(feature_extractor_config,
                                 is_training,
                                 freeze_batchnorm,
                                 reuse_weights=None):
  """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.

  Types registered in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP are built with
  Keras layer hyperparams and batch norm options; types registered in
  SSD_FEATURE_EXTRACTOR_CLASS_MAP are built with an arg_scope hyperparams
  builder and `reuse_weights`.

  Args:
    feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
    is_training: True if this feature extractor is being built for training.
    freeze_batchnorm: Whether to freeze batch norm parameters during
      training or not. When training with a small batch size (e.g. 1), it is
      desirable to freeze batch norm update and use pretrained batch norm
      params. Only forwarded to Keras feature extractors.
    reuse_weights: if the feature extractor should reuse weights. Only
      forwarded to non-Keras feature extractors.

  Returns:
    ssd_meta_arch.SSDFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  feature_type = feature_extractor_config.type
  is_keras_extractor = feature_type in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP

  # Resolve (and validate) the extractor class before building anything else,
  # so an unknown type is reported as such rather than as a failure while
  # building conv hyperparams.
  if is_keras_extractor:
    feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
        feature_type]
  elif feature_type in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
    feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
  else:
    raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))

  kwargs = {
      'is_training': is_training,
      'depth_multiplier': feature_extractor_config.depth_multiplier,
      'min_depth': feature_extractor_config.min_depth,
      'pad_to_multiple': feature_extractor_config.pad_to_multiple,
      'use_explicit_padding': feature_extractor_config.use_explicit_padding,
      'use_depthwise': feature_extractor_config.use_depthwise,
      'override_base_feature_extractor_hyperparams':
          feature_extractor_config.override_base_feature_extractor_hyperparams,
  }

  # These optional proto fields are forwarded only when explicitly set;
  # presumably so the extractor constructor's own defaults apply otherwise.
  if feature_extractor_config.HasField('replace_preprocessor_with_placeholder'):
    kwargs['replace_preprocessor_with_placeholder'] = (
        feature_extractor_config.replace_preprocessor_with_placeholder)

  if feature_extractor_config.HasField('num_layers'):
    kwargs['num_layers'] = feature_extractor_config.num_layers

  # Keras and non-Keras extractors take differently named hyperparams
  # arguments and different batch norm / weight-reuse options.
  if is_keras_extractor:
    kwargs.update({
        'conv_hyperparams': hyperparams_builder.KerasLayerHyperparams(
            feature_extractor_config.conv_hyperparams),
        'inplace_batchnorm_update': False,
        'freeze_batchnorm': freeze_batchnorm,
    })
  else:
    kwargs.update({
        'conv_hyperparams_fn': hyperparams_builder.build(
            feature_extractor_config.conv_hyperparams, is_training),
        'reuse_weights': reuse_weights,
    })

  if feature_extractor_config.HasField('fpn'):
    kwargs.update({
        'fpn_min_level': feature_extractor_config.fpn.min_level,
        'fpn_max_level': feature_extractor_config.fpn.max_level,
        'additional_layer_depth':
            feature_extractor_config.fpn.additional_layer_depth,
    })

  return feature_extractor_class(**kwargs)


def _build_ssd_model(ssd_config, is_training, add_summaries):
  """Builds an SSD detection model based on the model config.

  Args:
    ssd_config: A ssd.proto object containing the config for the desired
      SSDMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    SSDMetaArch based on the config.

  Raises:
    ValueError: If ssd_config.feature_extractor.type is not registered in
      SSD_FEATURE_EXTRACTOR_CLASS_MAP or SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP.
  """
  num_classes = ssd_config.num_classes

  # Feature extractor
  feature_extractor = _build_ssd_feature_extractor(
      feature_extractor_config=ssd_config.feature_extractor,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      is_training=is_training)

  box_coder = box_coder_builder.build(ssd_config.box_coder)
  matcher = matcher_builder.build(ssd_config.matcher)
  region_similarity_calculator = sim_calc.build(
      ssd_config.similarity_calculator)
  encode_background_as_zeros = ssd_config.encode_background_as_zeros
  negative_class_weight = ssd_config.negative_class_weight
  anchor_generator = anchor_generator_builder.build(
      ssd_config.anchor_generator)

  # The box predictor flavor (Keras vs. arg_scope based) follows the feature
  # extractor's flavor.
  if feature_extractor.is_keras_model:
    ssd_box_predictor = box_predictor_builder.build_keras(
        hyperparams_fn=hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=ssd_config.freeze_batchnorm,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=(
            anchor_generator.num_anchors_per_location()),
        box_predictor_config=ssd_config.box_predictor,
        is_training=is_training,
        num_classes=num_classes,
        add_background_class=ssd_config.add_background_class)
  else:
    ssd_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build, ssd_config.box_predictor, is_training,
        num_classes, ssd_config.add_background_class)

  image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
  non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
      ssd_config.post_processing)
  (classification_loss, localization_loss, classification_weight,
   localization_weight, hard_example_miner, random_example_sampler,
   expected_loss_weights_fn) = losses_builder.build(ssd_config.loss)
  normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
  normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize

  equalization_loss_config = ops.EqualizationLossConfig(
      weight=ssd_config.loss.equalization_loss.weight,
      exclude_prefixes=ssd_config.loss.equalization_loss.exclude_prefixes)

  target_assigner_instance = target_assigner.TargetAssigner(
      region_similarity_calculator,
      matcher,
      box_coder,
      negative_class_weight=negative_class_weight)

  return ssd_meta_arch.SSDMetaArch(
      is_training=is_training,
      anchor_generator=anchor_generator,
      box_predictor=ssd_box_predictor,
      box_coder=box_coder,
      feature_extractor=feature_extractor,
      encode_background_as_zeros=encode_background_as_zeros,
      image_resizer_fn=image_resizer_fn,
      non_max_suppression_fn=non_max_suppression_fn,
      score_conversion_fn=score_conversion_fn,
      classification_loss=classification_loss,
      localization_loss=localization_loss,
      classification_loss_weight=classification_weight,
      localization_loss_weight=localization_weight,
      normalize_loss_by_num_matches=normalize_loss_by_num_matches,
      hard_example_miner=hard_example_miner,
      target_assigner_instance=target_assigner_instance,
      add_summaries=add_summaries,
      normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
      add_background_class=ssd_config.add_background_class,
      explicit_background_class=ssd_config.explicit_background_class,
      random_example_sampler=random_example_sampler,
      expected_loss_weights_fn=expected_loss_weights_fn,
      use_confidences_as_targets=ssd_config.use_confidences_as_targets,
      implicit_example_weight=ssd_config.implicit_example_weight,
      equalization_loss_config=equalization_loss_config,
      return_raw_detections_during_predict=(
          ssd_config.return_raw_detections_during_predict))


def _build_faster_rcnn_feature_extractor(
328
329
    feature_extractor_config, is_training, reuse_weights=None,
    inplace_batchnorm_update=False):
330
331
332
333
334
335
336
  """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    reuse_weights: if the feature extractor should reuse weights.
337
338
339
340
341
    inplace_batchnorm_update: Whether to update batch_norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
      norm moving average parameters.
342
343
344
345
346
347
348

  Returns:
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
349
350
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')
351
352
353
  feature_type = feature_extractor_config.type
  first_stage_features_stride = (
      feature_extractor_config.first_stage_features_stride)
354
  batch_norm_trainable = feature_extractor_config.batch_norm_trainable
355
356
357
358
359
360
361

  if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
362
      is_training, first_stage_features_stride,
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
      batch_norm_trainable, reuse_weights=reuse_weights)


def _build_faster_rcnn_keras_feature_extractor(
    feature_extractor_config, is_training,
    inplace_batchnorm_update=False):
  """Builds a faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor from config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    inplace_batchnorm_update: Whether to update batch_norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
      norm moving average parameters.

  Returns:
    faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')
  feature_type = feature_extractor_config.type
  first_stage_features_stride = (
      feature_extractor_config.first_stage_features_stride)
  batch_norm_trainable = feature_extractor_config.batch_norm_trainable

  if feature_type not in FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
      is_training, first_stage_features_stride,
      batch_norm_trainable)
402
403


def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds R-FCN model if the second_stage_box_predictor in the config is of type
  `rfcn_box_predictor` else builds a Faster R-CNN model.

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    RFCNMetaArch if the built second stage box predictor is an R-FCN box
    predictor (Keras or not), otherwise FasterRCNNMetaArch.

  Raises:
    ValueError: If first_stage_nms_iou_threshold is outside [0, 1], if
      second_stage_batch_size exceeds first_stage_max_proposals while
      training, if inplace_batchnorm_update is requested, or if the feature
      extractor type is not recognized.
  """
  # Validate plain config constraints before constructing any components.
  if (frcnn_config.first_stage_nms_iou_threshold < 0 or
      frcnn_config.first_stage_nms_iou_threshold > 1.0):
    raise ValueError('iou_threshold not in [0, 1.0].')
  first_stage_max_proposals = frcnn_config.first_stage_max_proposals
  if (is_training and
      frcnn_config.second_stage_batch_size > first_stage_max_proposals):
    raise ValueError('second_stage_batch_size should be no greater than '
                     'first_stage_max_proposals.')

  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

  # The feature extractor type decides whether the Keras or the arg_scope
  # based builders are used for the extractor, hyperparams and predictors.
  is_keras = (frcnn_config.feature_extractor.type in
              FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP)

  if is_keras:
    feature_extractor = _build_faster_rcnn_keras_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
  else:
    feature_extractor = _build_faster_rcnn_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

  number_of_stages = frcnn_config.number_of_stages
  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'proposal',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
  if is_keras:
    first_stage_box_predictor_arg_scope_fn = (
        hyperparams_builder.KerasLayerHyperparams(
            frcnn_config.first_stage_box_predictor_conv_hyperparams))
  else:
    first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
        frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      frcnn_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
  first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
  # Static shapes are used when enabled, except during evaluation unless
  # use_static_shapes_for_eval is also set.
  use_static_shapes = frcnn_config.use_static_shapes and (
      frcnn_config.use_static_shapes_for_eval or is_training)
  first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  first_stage_non_max_suppression_fn = functools.partial(
      post_processing.batch_multiclass_non_max_suppression,
      score_thresh=frcnn_config.first_stage_nms_score_threshold,
      iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
      max_size_per_class=first_stage_max_proposals,
      max_total_size=first_stage_max_proposals,
      use_static_shapes=use_static_shapes,
      use_partitioned_nms=frcnn_config.use_partitioned_nms_in_first_stage,
      use_combined_nms=frcnn_config.use_combined_nms_in_first_stage)
  first_stage_loc_loss_weight = (
      frcnn_config.first_stage_localization_loss_weight)
  first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

  initial_crop_size = frcnn_config.initial_crop_size
  maxpool_kernel_size = frcnn_config.maxpool_kernel_size
  maxpool_stride = frcnn_config.maxpool_stride

  second_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'detection',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  if is_keras:
    second_stage_box_predictor = box_predictor_builder.build_keras(
        hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=False,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=[1],
        box_predictor_config=frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  else:
    second_stage_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  second_stage_batch_size = frcnn_config.second_stage_batch_size
  second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.second_stage_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
  ) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      frcnn_config.second_stage_localization_loss_weight)
  second_stage_classification_loss = (
      losses_builder.build_faster_rcnn_classification_loss(
          frcnn_config.second_stage_classification_loss))
  second_stage_classification_loss_weight = (
      frcnn_config.second_stage_classification_loss_weight)
  second_stage_mask_prediction_loss_weight = (
      frcnn_config.second_stage_mask_prediction_loss_weight)

  hard_example_miner = None
  if frcnn_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        frcnn_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)

  crop_and_resize_fn = (
      ops.matmul_crop_and_resize if frcnn_config.use_matmul_crop_and_resize
      else ops.native_crop_and_resize)
  clip_anchors_to_image = frcnn_config.clip_anchors_to_image

  # Arguments shared by both RFCNMetaArch and FasterRCNNMetaArch.
  common_kwargs = {
      'is_training': is_training,
      'num_classes': num_classes,
      'image_resizer_fn': image_resizer_fn,
      'feature_extractor': feature_extractor,
      'number_of_stages': number_of_stages,
      'first_stage_anchor_generator': first_stage_anchor_generator,
      'first_stage_target_assigner': first_stage_target_assigner,
      'first_stage_atrous_rate': first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope_fn':
          first_stage_box_predictor_arg_scope_fn,
      'first_stage_box_predictor_kernel_size':
          first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
      'first_stage_minibatch_size': first_stage_minibatch_size,
      'first_stage_sampler': first_stage_sampler,
      'first_stage_non_max_suppression_fn': first_stage_non_max_suppression_fn,
      'first_stage_max_proposals': first_stage_max_proposals,
      'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
      'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
      'second_stage_target_assigner': second_stage_target_assigner,
      'second_stage_batch_size': second_stage_batch_size,
      'second_stage_sampler': second_stage_sampler,
      'second_stage_non_max_suppression_fn':
          second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
          second_stage_localization_loss_weight,
      'second_stage_classification_loss':
          second_stage_classification_loss,
      'second_stage_classification_loss_weight':
          second_stage_classification_loss_weight,
      'hard_example_miner': hard_example_miner,
      'add_summaries': add_summaries,
      'crop_and_resize_fn': crop_and_resize_fn,
      'clip_anchors_to_image': clip_anchors_to_image,
      'use_static_shapes': use_static_shapes,
      'resize_masks': frcnn_config.resize_masks,
      'return_raw_detections_during_predict': (
          frcnn_config.return_raw_detections_during_predict)
  }

  # The kind of second stage box predictor that was built selects the meta
  # architecture: R-FCN predictors (Keras or not) yield an RFCNMetaArch.
  if isinstance(second_stage_box_predictor,
                (rfcn_box_predictor.RfcnBoxPredictor,
                 rfcn_keras_box_predictor.RfcnKerasBoxPredictor)):
    return rfcn_meta_arch.RFCNMetaArch(
        second_stage_rfcn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
  else:
    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=initial_crop_size,
        maxpool_kernel_size=maxpool_kernel_size,
        maxpool_stride=maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        second_stage_mask_prediction_loss_weight=(
            second_stage_mask_prediction_loss_weight),
        **common_kwargs)

# Registry of experimental meta architecture builders, keyed by the `name`
# field of an experimental_model config. Empty in this file; each registered
# value is called by _build_experimental_model as fn(is_training,
# add_summaries).
EXPERIMENTAL_META_ARCH_BUILDER_MAP = dict()


def _build_experimental_model(config, is_training, add_summaries=True):
  """Builds a model registered in EXPERIMENTAL_META_ARCH_BUILDER_MAP.

  Only `config.name` is consulted; the config object itself is not passed on
  to the registered builder.

  Args:
    config: The experimental_model config whose `name` selects the builder.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    Whatever the registered builder returns.

  Raises:
    KeyError: If `config.name` is not registered.
  """
  builder_fn = EXPERIMENTAL_META_ARCH_BUILDER_MAP[config.name]
  return builder_fn(is_training, add_summaries)

# Maps the name of the field set in a DetectionModel's `model` oneof to the
# builder for that meta architecture; build() calls the selected builder as
# fn(sub_config, is_training, add_summaries). The misspelled constant name is
# kept as is, since other modules may reference it.
META_ARCHITECURE_BUILDER_MAP = dict(
    ssd=_build_ssd_model,
    faster_rcnn=_build_faster_rcnn_model,
    experimental_model=_build_experimental_model,
)


def build(model_config, is_training, add_summaries=True):
  """Creates a DetectionModel from a model.proto configuration.

  The meta architecture is chosen by whichever field of the `model` oneof is
  set on `model_config`. That sub-config is handed to the builder registered
  under the same name in META_ARCHITECURE_BUILDER_MAP.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tensorflow summaries in the model graph.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: If model_config is not a model_pb2.DetectionModel or names an
      unknown meta architecture. The selected builder may raise ValueError
      for an invalid model config as well.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')

  arch_name = model_config.WhichOneof('model')
  if arch_name not in META_ARCHITECURE_BUILDER_MAP:
    raise ValueError('Unknown meta architecture: {}'.format(arch_name))

  arch_builder = META_ARCHITECURE_BUILDER_MAP[arch_name]
  arch_config = getattr(model_config, arch_name)
  return arch_builder(arch_config, is_training, add_summaries)