"scripts/git@developer.sourcefind.cn:change/sglang.git" did not exist on "25be63d0b2f6155fcf2f9790e3084c99db79752b"
Commit fd499ca2 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 397467113
parent a6b0c050
...@@ -77,6 +77,7 @@ class RetinaNetModel(tf.keras.Model): ...@@ -77,6 +77,7 @@ class RetinaNetModel(tf.keras.Model):
images: tf.Tensor, images: tf.Tensor,
image_shape: Optional[tf.Tensor] = None, image_shape: Optional[tf.Tensor] = None,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None, anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
output_intermediate_features: bool = False,
training: bool = None) -> Mapping[str, tf.Tensor]: training: bool = None) -> Mapping[str, tf.Tensor]:
"""Forward pass of the RetinaNet model. """Forward pass of the RetinaNet model.
...@@ -92,6 +93,8 @@ class RetinaNetModel(tf.keras.Model): ...@@ -92,6 +93,8 @@ class RetinaNetModel(tf.keras.Model):
- key: `str`, the level of the multilevel predictions. - key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the anchor coordinates of a particular feature - values: `Tensor`, the anchor coordinates of a particular feature
level, whose shape is [height_l, width_l, num_anchors_per_location]. level, whose shape is [height_l, width_l, num_anchors_per_location].
output_intermediate_features: `bool` indicating whether to return the
intermediate feature maps generated by backbone and decoder.
training: `bool`, indicating whether it is in training mode. training: `bool`, indicating whether it is in training mode.
Returns: Returns:
...@@ -112,19 +115,26 @@ class RetinaNetModel(tf.keras.Model): ...@@ -112,19 +115,26 @@ class RetinaNetModel(tf.keras.Model):
feature level, whose shape is feature level, whose shape is
[batch, height_l, width_l, att_size * num_anchors_per_location]. [batch, height_l, width_l, att_size * num_anchors_per_location].
""" """
outputs = {}
# Feature extraction. # Feature extraction.
features = self.backbone(images) features = self.backbone(images)
if output_intermediate_features:
outputs.update(
{'backbone_{}'.format(k): v for k, v in features.items()})
if self.decoder: if self.decoder:
features = self.decoder(features) features = self.decoder(features)
if output_intermediate_features:
outputs.update(
{'decoder_{}'.format(k): v for k, v in features.items()})
# Dense prediction. `raw_attributes` can be empty. # Dense prediction. `raw_attributes` can be empty.
raw_scores, raw_boxes, raw_attributes = self.head(features) raw_scores, raw_boxes, raw_attributes = self.head(features)
if training: if training:
outputs = { outputs.update({
'cls_outputs': raw_scores, 'cls_outputs': raw_scores,
'box_outputs': raw_boxes, 'box_outputs': raw_boxes,
} })
if raw_attributes: if raw_attributes:
outputs.update({'attribute_outputs': raw_attributes}) outputs.update({'attribute_outputs': raw_attributes})
return outputs return outputs
...@@ -145,12 +155,13 @@ class RetinaNetModel(tf.keras.Model): ...@@ -145,12 +155,13 @@ class RetinaNetModel(tf.keras.Model):
[tf.shape(images)[0], 1, 1, 1]) [tf.shape(images)[0], 1, 1, 1])
# Post-processing. # Post-processing.
final_results = self.detection_generator( final_results = self.detection_generator(raw_boxes, raw_scores,
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes) anchor_boxes, image_shape,
outputs = { raw_attributes)
outputs.update({
'cls_outputs': raw_scores, 'cls_outputs': raw_scores,
'box_outputs': raw_boxes, 'box_outputs': raw_boxes,
} })
if self.detection_generator.get_config()['apply_nms']: if self.detection_generator.get_config()['apply_nms']:
outputs.update({ outputs.update({
'detection_boxes': final_results['detection_boxes'], 'detection_boxes': final_results['detection_boxes'],
......
...@@ -147,8 +147,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -147,8 +147,10 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
], ],
training=[True, False], training=[True, False],
has_att_heads=[True, False], has_att_heads=[True, False],
output_intermediate_features=[True, False],
)) ))
def test_forward(self, strategy, image_size, training, has_att_heads): def test_forward(self, strategy, image_size, training, has_att_heads,
output_intermediate_features):
"""Test for creation of a R50-FPN RetinaNet.""" """Test for creation of a R50-FPN RetinaNet."""
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
num_classes = 3 num_classes = 3
...@@ -202,6 +204,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -202,6 +204,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
images, images,
image_shape, image_shape,
anchor_boxes, anchor_boxes,
output_intermediate_features=output_intermediate_features,
training=training) training=training)
if training: if training:
...@@ -247,6 +250,19 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -247,6 +250,19 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual( self.assertAllEqual(
[2, 10, 1], [2, 10, 1],
model_outputs['detection_attributes']['depth'].numpy().shape) model_outputs['detection_attributes']['depth'].numpy().shape)
if output_intermediate_features:
for l in range(2, 6):
self.assertIn('backbone_{}'.format(l), model_outputs)
self.assertAllEqual([
2, image_size[0] // 2**l, image_size[1] // 2**l,
backbone.output_specs[str(l)].as_list()[-1]
], model_outputs['backbone_{}'.format(l)].numpy().shape)
for l in range(min_level, max_level + 1):
self.assertIn('decoder_{}'.format(l), model_outputs)
self.assertAllEqual([
2, image_size[0] // 2**l, image_size[1] // 2**l,
decoder.output_specs[str(l)].as_list()[-1]
], model_outputs['decoder_{}'.format(l)].numpy().shape)
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
"""Validate the network can be serialized and deserialized.""" """Validate the network can be serialized and deserialized."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment