# model_builder.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A function to build a DetectionModel from configuration."""
17

18
import functools
19

20
21
22
23
24
25
26
27
28
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
29
from object_detection.core import balanced_positive_negative_sampler as sampler
30
from object_detection.core import post_processing
31
from object_detection.core import target_assigner
32
33
34
35
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
36
37
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
38
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
39
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
40
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
41
from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn
42
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
43
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
44
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
45
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
46
from object_detection.models.ssd_mobilenet_v1_fpn_feature_extractor import SSDMobileNetV1FpnFeatureExtractor
47
from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor
48
from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor
49
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
50
from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor
51
52
from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor
from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
53
from object_detection.predictors import rfcn_box_predictor
54
from object_detection.predictors.heads import mask_head
55
from object_detection.protos import model_pb2
56
from object_detection.utils import ops
57
58
59
60

# A map of names to SSD feature extractors.
SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
    # Keys are matched against `SSDFeatureExtractor.type` from the SSD config
    # in _build_ssd_feature_extractor (non-Keras / slim extractors).
    'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
    'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
    'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
    'ssd_mobilenet_v1_fpn': SSDMobileNetV1FpnFeatureExtractor,
    'ssd_mobilenet_v1_ppn': SSDMobileNetV1PpnFeatureExtractor,
    'ssd_mobilenet_v2': SSDMobileNetV2FeatureExtractor,
    'ssd_mobilenet_v2_fpn': SSDMobileNetV2FpnFeatureExtractor,
    'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
    'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
    'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
    'ssd_resnet50_v1_ppn': ssd_resnet_v1_ppn.SSDResnet50V1PpnFeatureExtractor,
    'ssd_resnet101_v1_ppn':
        ssd_resnet_v1_ppn.SSDResnet101V1PpnFeatureExtractor,
    'ssd_resnet152_v1_ppn':
        ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor,
    'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
    'ssd_pnasnet': SSDPNASNetFeatureExtractor,
}

# A map of names to Keras-based SSD feature extractors. These are constructed
# with Keras layer hyperparams instead of a slim arg-scope function, and are
# checked before SSD_FEATURE_EXTRACTOR_CLASS_MAP in
# _build_ssd_feature_extractor.
SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
    'ssd_mobilenet_v1_keras': SSDMobileNetV1KerasFeatureExtractor,
    'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor,
}

# A map of names to Faster R-CNN feature extractors.
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    # Keys are matched against `FasterRcnnFeatureExtractor.type` from the
    # Faster R-CNN config in _build_faster_rcnn_feature_extractor.
    'faster_rcnn_nas':
    frcnn_nas.FasterRCNNNASFeatureExtractor,
    'faster_rcnn_pnas':
    frcnn_pnas.FasterRCNNPNASFeatureExtractor,
    'faster_rcnn_inception_resnet_v2':
    frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
    'faster_rcnn_inception_v2':
    frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
    'faster_rcnn_resnet50':
    frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
    'faster_rcnn_resnet101':
    frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
    'faster_rcnn_resnet152':
    frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
}


103
def build(model_config, is_training, add_summaries=True):
  """Builds a DetectionModel based on the model config.

  Dispatches on the populated field of the `model` oneof: `ssd` builds an
  SSDMetaArch, `faster_rcnn` builds a FasterRCNNMetaArch or RFCNMetaArch.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tensorflow summaries in the model graph.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: If model_config is not a model_pb2.DetectionModel, if the
      `model` oneof holds an unsupported (or no) meta architecture, or if the
      selected architecture's builder rejects its config.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')

  # Name of the field set in the `model` oneof.
  meta_architecture = model_config.WhichOneof('model')
  if meta_architecture == 'ssd':
    return _build_ssd_model(model_config.ssd, is_training, add_summaries)
  if meta_architecture == 'faster_rcnn':
    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
                                    add_summaries)
  raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))


128
129
130
def _build_ssd_feature_extractor(feature_extractor_config,
                                 is_training,
                                 freeze_batchnorm,
                                 reuse_weights=None):
  """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.

  Args:
    feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
    is_training: True if this feature extractor is being built for training.
    freeze_batchnorm: Whether to freeze batch norm parameters during
      training or not. When training with a small batch size (e.g. 1), it is
      desirable to freeze batch norm update and use pretrained batch norm
      params. Only forwarded to Keras-based feature extractors.
    reuse_weights: if the feature extractor should reuse weights. Only
      forwarded to non-Keras feature extractors.

  Returns:
    ssd_meta_arch.SSDFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  feature_type = feature_extractor_config.type
  is_keras_extractor = feature_type in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP

  # Resolve (and validate) the extractor class before building anything from
  # the config, so an unknown type is reported as such rather than surfacing
  # as an error from the hyperparams builders.
  if is_keras_extractor:
    feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
        feature_type]
  elif feature_type in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
    feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
  else:
    raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))

  # Arguments shared by Keras and non-Keras extractors.
  kwargs = {
      'is_training': is_training,
      'depth_multiplier': feature_extractor_config.depth_multiplier,
      'min_depth': feature_extractor_config.min_depth,
      'pad_to_multiple': feature_extractor_config.pad_to_multiple,
      'use_explicit_padding': feature_extractor_config.use_explicit_padding,
      'use_depthwise': feature_extractor_config.use_depthwise,
      'override_base_feature_extractor_hyperparams':
          feature_extractor_config.override_base_feature_extractor_hyperparams,
  }

  # Only forward this flag when it was explicitly set in the config.
  if feature_extractor_config.HasField('replace_preprocessor_with_placeholder'):
    kwargs['replace_preprocessor_with_placeholder'] = (
        feature_extractor_config.replace_preprocessor_with_placeholder)

  if is_keras_extractor:
    # Keras extractors take a KerasLayerHyperparams object and never update
    # batch norm in place.
    kwargs.update({
        'conv_hyperparams': hyperparams_builder.KerasLayerHyperparams(
            feature_extractor_config.conv_hyperparams),
        'inplace_batchnorm_update': False,
        'freeze_batchnorm': freeze_batchnorm,
    })
  else:
    # Non-Keras extractors take the function returned by
    # hyperparams_builder.build.
    kwargs.update({
        'conv_hyperparams_fn': hyperparams_builder.build(
            feature_extractor_config.conv_hyperparams, is_training),
        'reuse_weights': reuse_weights,
    })

  if feature_extractor_config.HasField('fpn'):
    fpn_config = feature_extractor_config.fpn
    kwargs.update({
        'fpn_min_level': fpn_config.min_level,
        'fpn_max_level': fpn_config.max_level,
        'additional_layer_depth': fpn_config.additional_layer_depth,
    })

  return feature_extractor_class(**kwargs)
221
222


223
def _build_ssd_model(ssd_config, is_training, add_summaries):
  """Builds an SSD detection model based on the model config.

  Args:
    ssd_config: A ssd.proto object containing the config for the desired
      SSDMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    SSDMetaArch based on the config.

  Raises:
    ValueError: If ssd_config.feature_extractor.type is not recognized (i.e.
      not registered in SSD_FEATURE_EXTRACTOR_CLASS_MAP or
      SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP).
  """
  num_classes = ssd_config.num_classes

  # Feature extractor
  feature_extractor = _build_ssd_feature_extractor(
      feature_extractor_config=ssd_config.feature_extractor,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      is_training=is_training)

  box_coder = box_coder_builder.build(ssd_config.box_coder)
  matcher = matcher_builder.build(ssd_config.matcher)
  region_similarity_calculator = sim_calc.build(
      ssd_config.similarity_calculator)
  encode_background_as_zeros = ssd_config.encode_background_as_zeros
  negative_class_weight = ssd_config.negative_class_weight
  anchor_generator = anchor_generator_builder.build(
      ssd_config.anchor_generator)

  # The box predictor flavour follows the feature extractor: a Keras
  # extractor gets a Keras box predictor, otherwise the slim builder is used.
  if feature_extractor.is_keras_model:
    ssd_box_predictor = box_predictor_builder.build_keras(
        conv_hyperparams_fn=hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=ssd_config.freeze_batchnorm,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=anchor_generator
        .num_anchors_per_location(),
        box_predictor_config=ssd_config.box_predictor,
        is_training=is_training,
        num_classes=num_classes,
        add_background_class=ssd_config.add_background_class)
  else:
    ssd_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build, ssd_config.box_predictor, is_training,
        num_classes, ssd_config.add_background_class)

  image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
  non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
      ssd_config.post_processing)
  (classification_loss, localization_loss, classification_weight,
   localization_weight, hard_example_miner, random_example_sampler,
   expected_loss_weights_fn) = losses_builder.build(ssd_config.loss)
  normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
  normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize

  equalization_loss_config = ops.EqualizationLossConfig(
      weight=ssd_config.loss.equalization_loss.weight,
      exclude_prefixes=ssd_config.loss.equalization_loss.exclude_prefixes)

  target_assigner_instance = target_assigner.TargetAssigner(
      region_similarity_calculator,
      matcher,
      box_coder,
      negative_class_weight=negative_class_weight)

  return ssd_meta_arch.SSDMetaArch(
      is_training=is_training,
      anchor_generator=anchor_generator,
      box_predictor=ssd_box_predictor,
      box_coder=box_coder,
      feature_extractor=feature_extractor,
      encode_background_as_zeros=encode_background_as_zeros,
      image_resizer_fn=image_resizer_fn,
      non_max_suppression_fn=non_max_suppression_fn,
      score_conversion_fn=score_conversion_fn,
      classification_loss=classification_loss,
      localization_loss=localization_loss,
      classification_loss_weight=classification_weight,
      localization_loss_weight=localization_weight,
      normalize_loss_by_num_matches=normalize_loss_by_num_matches,
      hard_example_miner=hard_example_miner,
      target_assigner_instance=target_assigner_instance,
      add_summaries=add_summaries,
      normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
      add_background_class=ssd_config.add_background_class,
      explicit_background_class=ssd_config.explicit_background_class,
      random_example_sampler=random_example_sampler,
      expected_loss_weights_fn=expected_loss_weights_fn,
      use_confidences_as_targets=ssd_config.use_confidences_as_targets,
      implicit_example_weight=ssd_config.implicit_example_weight,
      equalization_loss_config=equalization_loss_config)
320
321
322


def _build_faster_rcnn_feature_extractor(
    feature_extractor_config, is_training, reuse_weights=None,
    inplace_batchnorm_update=False):
  """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    reuse_weights: if the feature extractor should reuse weights.
    inplace_batchnorm_update: Whether to update batch_norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
      norm moving average parameters. Only False is accepted here.

  Returns:
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Raises:
    ValueError: If inplace_batchnorm_update is True, or on invalid feature
      extractor type.
  """
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')

  feature_type = feature_extractor_config.type
  if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]

  # Constructor arguments are passed positionally as: is_training,
  # first_stage_features_stride, batch_norm_trainable, reuse_weights.
  return feature_extractor_class(
      is_training,
      feature_extractor_config.first_stage_features_stride,
      feature_extractor_config.batch_norm_trainable,
      reuse_weights)
359
360


361
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds R-FCN model if the second_stage_box_predictor in the config is of type
  `rfcn_box_predictor` else builds a Faster R-CNN model.

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    RFCNMetaArch if the built second stage box predictor is an
    rfcn_box_predictor.RfcnBoxPredictor, otherwise FasterRCNNMetaArch.

  Raises:
    ValueError: If first_stage_nms_iou_threshold is outside [0, 1], if
      second_stage_batch_size exceeds first_stage_max_proposals while
      training, or if the feature extractor config is rejected by
      _build_faster_rcnn_feature_extractor.
  """
  # Validate scalar constraints before building any model components so an
  # invalid config fails fast.
  if (frcnn_config.first_stage_nms_iou_threshold < 0 or
      frcnn_config.first_stage_nms_iou_threshold > 1.0):
    raise ValueError('iou_threshold not in [0, 1.0].')
  first_stage_max_proposals = frcnn_config.first_stage_max_proposals
  second_stage_batch_size = frcnn_config.second_stage_batch_size
  if is_training and second_stage_batch_size > first_stage_max_proposals:
    raise ValueError('second_stage_batch_size should be no greater than '
                     'first_stage_max_proposals.')

  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

  feature_extractor = _build_faster_rcnn_feature_extractor(
      frcnn_config.feature_extractor, is_training,
      inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'proposal',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
      frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)

  # Static shapes are used when requested; at eval time they additionally
  # require use_static_shapes_for_eval.
  use_static_shapes = frcnn_config.use_static_shapes and (
      frcnn_config.use_static_shapes_for_eval or is_training)
  # Both stage samplers share the same static/dynamic setting.
  use_static_sampler = (frcnn_config.use_static_balanced_label_sampler and
                        use_static_shapes)

  first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
      is_static=use_static_sampler)
  first_stage_non_max_suppression_fn = functools.partial(
      post_processing.batch_multiclass_non_max_suppression,
      score_thresh=frcnn_config.first_stage_nms_score_threshold,
      iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
      max_size_per_class=first_stage_max_proposals,
      max_total_size=first_stage_max_proposals,
      use_static_shapes=use_static_shapes)

  second_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'detection',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  second_stage_box_predictor = box_predictor_builder.build(
      hyperparams_builder.build,
      frcnn_config.second_stage_box_predictor,
      is_training=is_training,
      num_classes=num_classes)
  second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.second_stage_balance_fraction,
      is_static=use_static_sampler)
  (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
  ) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      frcnn_config.second_stage_localization_loss_weight)
  second_stage_classification_loss = (
      losses_builder.build_faster_rcnn_classification_loss(
          frcnn_config.second_stage_classification_loss))
  second_stage_classification_loss_weight = (
      frcnn_config.second_stage_classification_loss_weight)

  hard_example_miner = None
  if frcnn_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        frcnn_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)

  crop_and_resize_fn = (
      ops.matmul_crop_and_resize if frcnn_config.use_matmul_crop_and_resize
      else ops.native_crop_and_resize)

  # Arguments accepted by both RFCNMetaArch and FasterRCNNMetaArch.
  common_kwargs = {
      'is_training': is_training,
      'num_classes': num_classes,
      'image_resizer_fn': image_resizer_fn,
      'feature_extractor': feature_extractor,
      'number_of_stages': frcnn_config.number_of_stages,
      'first_stage_anchor_generator': first_stage_anchor_generator,
      'first_stage_target_assigner': first_stage_target_assigner,
      'first_stage_atrous_rate': frcnn_config.first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope_fn':
      first_stage_box_predictor_arg_scope_fn,
      'first_stage_box_predictor_kernel_size':
      frcnn_config.first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth':
      frcnn_config.first_stage_box_predictor_depth,
      'first_stage_minibatch_size': frcnn_config.first_stage_minibatch_size,
      'first_stage_sampler': first_stage_sampler,
      'first_stage_non_max_suppression_fn': first_stage_non_max_suppression_fn,
      'first_stage_max_proposals': first_stage_max_proposals,
      'first_stage_localization_loss_weight':
      frcnn_config.first_stage_localization_loss_weight,
      'first_stage_objectness_loss_weight':
      frcnn_config.first_stage_objectness_loss_weight,
      'second_stage_target_assigner': second_stage_target_assigner,
      'second_stage_batch_size': second_stage_batch_size,
      'second_stage_sampler': second_stage_sampler,
      'second_stage_non_max_suppression_fn':
      second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
      second_stage_localization_loss_weight,
      'second_stage_classification_loss':
      second_stage_classification_loss,
      'second_stage_classification_loss_weight':
      second_stage_classification_loss_weight,
      'hard_example_miner': hard_example_miner,
      'add_summaries': add_summaries,
      'crop_and_resize_fn': crop_and_resize_fn,
      'clip_anchors_to_image': frcnn_config.clip_anchors_to_image,
      'use_static_shapes': use_static_shapes,
      'resize_masks': frcnn_config.resize_masks,
  }

  if isinstance(second_stage_box_predictor,
                rfcn_box_predictor.RfcnBoxPredictor):
    return rfcn_meta_arch.RFCNMetaArch(
        second_stage_rfcn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
  return faster_rcnn_meta_arch.FasterRCNNMetaArch(
      initial_crop_size=frcnn_config.initial_crop_size,
      maxpool_kernel_size=frcnn_config.maxpool_kernel_size,
      maxpool_stride=frcnn_config.maxpool_stride,
      second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
      second_stage_mask_prediction_loss_weight=(
          frcnn_config.second_stage_mask_prediction_loss_weight),
      **common_kwargs)