Commit d088f0d5 authored by Cristina Vasconcelos's avatar Cristina Vasconcelos Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388951823
parent 027813d3
...@@ -76,7 +76,7 @@ def build_maskrcnn( ...@@ -76,7 +76,7 @@ def build_maskrcnn(
backbone_config=model_config.backbone, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config, norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
backbone(tf.keras.Input(input_specs.shape[1:])) backbone_features = backbone(tf.keras.Input(input_specs.shape[1:]))
decoder = decoders.factory.build_decoder( decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs, input_specs=backbone.output_specs,
...@@ -119,6 +119,13 @@ def build_maskrcnn( ...@@ -119,6 +119,13 @@ def build_maskrcnn(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer, kernel_regularizer=l2_regularizer,
name='detection_head') name='detection_head')
# Build backbone, decoder and region proposal network:
if decoder:
decoder_features = decoder(backbone_features)
rpn_head(decoder_features)
if roi_sampler_config.cascade_iou_thresholds: if roi_sampler_config.cascade_iou_thresholds:
detection_head_cascade = [detection_head] detection_head_cascade = [detection_head]
for cascade_num in range(len(roi_sampler_config.cascade_iou_thresholds)): for cascade_num in range(len(roi_sampler_config.cascade_iou_thresholds)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment