Commit aca30f93 authored by Cristina Vasconcelos's avatar Cristina Vasconcelos Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 397347570
parent b768c248
...@@ -121,8 +121,7 @@ def build_maskrcnn( ...@@ -121,8 +121,7 @@ def build_maskrcnn(
kernel_regularizer=l2_regularizer, kernel_regularizer=l2_regularizer,
name='detection_head') name='detection_head')
# Build backbone, decoder and region proposal network: # Builds decoder and region proposal network:
if decoder: if decoder:
decoder_features = decoder(backbone_features) decoder_features = decoder(backbone_features)
rpn_head(decoder_features) rpn_head(decoder_features)
...@@ -145,6 +144,7 @@ def build_maskrcnn( ...@@ -145,6 +144,7 @@ def build_maskrcnn(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer, kernel_regularizer=l2_regularizer,
name='detection_head_{}'.format(cascade_num + 1)) name='detection_head_{}'.format(cascade_num + 1))
detection_head_cascade.append(detection_head) detection_head_cascade.append(detection_head)
detection_head = detection_head_cascade detection_head = detection_head_cascade
...@@ -260,7 +260,7 @@ def build_retinanet( ...@@ -260,7 +260,7 @@ def build_retinanet(
backbone_config=model_config.backbone, backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config, norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
backbone(tf.keras.Input(input_specs.shape[1:])) backbone_features = backbone(tf.keras.Input(input_specs.shape[1:]))
decoder = decoders.factory.build_decoder( decoder = decoders.factory.build_decoder(
input_specs=backbone.output_specs, input_specs=backbone.output_specs,
...@@ -289,6 +289,11 @@ def build_retinanet( ...@@ -289,6 +289,11 @@ def build_retinanet(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
# Builds decoder and head so that their trainable weights are initialized
if decoder:
decoder_features = decoder(backbone_features)
_ = head(decoder_features)
detection_generator_obj = detection_generator.MultilevelDetectionGenerator( detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
apply_nms=generator_config.apply_nms, apply_nms=generator_config.apply_nms,
pre_nms_top_k=generator_config.pre_nms_top_k, pre_nms_top_k=generator_config.pre_nms_top_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment