Commit 011422cf authored by Fan Yang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 408690416
parent 7f2c1d5a
@@ -14,7 +14,7 @@
 """Segmentation heads."""
-from typing import Any, Union, Sequence, Mapping
+from typing import Any, Union, Sequence, Mapping, Tuple
 import tensorflow as tf
 from official.modeling import tf_utils
@@ -139,25 +139,29 @@ class SegmentationHead3D(tf.keras.layers.Layer):
     super(SegmentationHead3D, self).build(input_shape)

-  def call(self, backbone_output: Mapping[str, tf.Tensor],
-           decoder_output: Mapping[str, tf.Tensor]) -> tf.Tensor:
+  def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
+                               Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
     """Forward pass of the segmentation head.

-    Args:
-      backbone_output: a dict of tensors
-        - key: `str`, the level of the multilevel features.
-        - values: `Tensor`, the feature map tensors, whose shape is [batch,
-          height_l, width_l, channels].
-      decoder_output: a dict of tensors
-        - key: `str`, the level of the multilevel features.
-        - values: `Tensor`, the feature map tensors, whose shape is [batch,
-          height_l, width_l, channels].
+    It supports both a tuple of 2 tensors or 2 dictionaries. The first is
+    backbone endpoints, and the second is decoder endpoints. When inputs are
+    tensors, they are from a single level of feature maps. When inputs are
+    dictionaries, they contain multiple levels of feature maps, where the key
+    is the index of feature map.
+
+    Args:
+      inputs: A tuple of 2 feature map tensors of shape
+        [batch, height_l, width_l, channels] or 2 dictionaries of tensors:
+        - key: A `str` of the level of the multilevel features.
+        - values: A `tf.Tensor` of the feature map tensors, whose shape is
+          [batch, height_l, width_l, channels].

     Returns:
-      segmentation prediction mask: `Tensor`, the segmentation mask scores
-        predicted from input feature.
+      segmentation prediction mask: A `tf.Tensor` of the segmentation mask
+        scores predicted from input features.
     """
-    x = decoder_output[str(self._config_dict['level'])]
+    decoder_output = inputs[1]
+    x = decoder_output[str(self._config_dict['level'])] if isinstance(
+        decoder_output, dict) else decoder_output

     for i, conv in enumerate(self._convs):
       x = conv(x)
...
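The new docstring above defines a dual input contract: plain tensors carry a single feature level, while dictionaries carry multilevel features keyed by level index. The sketch below is not part of the commit; it only isolates the isinstance dispatch the rewritten call body relies on, using a hypothetical helper name.

# Standalone sketch of the dict-vs-tensor dispatch used by the new call body.
# `_select_level` is a hypothetical helper, not a function from the library.
import tensorflow as tf

def _select_level(features, level):
  # A dict of multilevel features is indexed by the stringified level;
  # a plain tensor is already the selected level and passes through.
  return features[str(level)] if isinstance(features, dict) else features

single = tf.zeros([2, 32, 32, 32, 16])        # one level of 3D features
multi = {'3': tf.zeros([2, 32, 32, 32, 16]),
         '4': tf.zeros([2, 16, 16, 16, 16])}  # multilevel features
assert _select_level(single, 3) is single
assert _select_level(multi, 3) is multi['3']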
@@ -42,7 +42,7 @@ class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
         '1': np.random.rand(2, 128, 128, 128, 16),
         '2': np.random.rand(2, 64, 64, 64, 16),
     }
-    logits = head(backbone_features, decoder_features)
+    logits = head((backbone_features, decoder_features))
     if str(level) in decoder_features:
       self.assertAllEqual(logits.numpy().shape, [
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains definitions of segmentation heads."""
-from typing import List, Union, Optional, Mapping
+from typing import List, Union, Optional, Mapping, Tuple
 import tensorflow as tf
 from official.modeling import tf_utils
@@ -204,16 +204,19 @@ class SegmentationHead(tf.keras.layers.Layer):
     super(SegmentationHead, self).build(input_shape)

-  def call(self, backbone_output: Mapping[str, tf.Tensor],
-           decoder_output: Mapping[str, tf.Tensor]):
+  def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
+                               Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
     """Forward pass of the segmentation head.

+    It supports both a tuple of 2 tensors or 2 dictionaries. The first is
+    backbone endpoints, and the second is decoder endpoints. When inputs are
+    tensors, they are from a single level of feature maps. When inputs are
+    dictionaries, they contain multiple levels of feature maps, where the key
+    is the index of feature map.
+
     Args:
-      backbone_output: A `dict` of tensors
-        - key: A `str` of the level of the multilevel features.
-        - values: A `tf.Tensor` of the feature map tensors, whose shape is
-          [batch, height_l, width_l, channels].
-      decoder_output: A `dict` of tensors
+      inputs: A tuple of 2 feature map tensors of shape
+        [batch, height_l, width_l, channels] or 2 dictionaries of tensors:
         - key: A `str` of the level of the multilevel features.
         - values: A `tf.Tensor` of the feature map tensors, whose shape is
           [batch, height_l, width_l, channels].
@@ -221,11 +224,14 @@ class SegmentationHead(tf.keras.layers.Layer):
       segmentation prediction mask: A `tf.Tensor` of the segmentation mask
         scores predicted from input features.
     """
+    backbone_output = inputs[0]
+    decoder_output = inputs[1]
     if self._config_dict['feature_fusion'] == 'deeplabv3plus':
       # deeplabv3+ feature fusion
-      x = decoder_output[str(self._config_dict['level'])]
-      y = backbone_output[str(
-          self._config_dict['low_level'])]
+      x = decoder_output[str(self._config_dict['level'])] if isinstance(
+          decoder_output, dict) else decoder_output
+      y = backbone_output[str(self._config_dict['low_level'])] if isinstance(
+          backbone_output, dict) else backbone_output
       y = self._dlv3p_norm(self._dlv3p_conv(y))
       y = self._activation(y)
@@ -234,12 +240,15 @@ class SegmentationHead(tf.keras.layers.Layer):
         x = tf.cast(x, dtype=y.dtype)
       x = tf.concat([x, y], axis=self._bn_axis)
     elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
+      if not isinstance(decoder_output, dict):
+        raise ValueError('Only support dictionary decoder_output.')
       x = nn_layers.pyramid_feature_fusion(decoder_output,
                                            self._config_dict['level'])
     elif self._config_dict['feature_fusion'] == 'panoptic_fpn_fusion':
       x = self._panoptic_fpn_fusion(decoder_output)
     else:
-      x = decoder_output[str(self._config_dict['level'])]
+      x = decoder_output[str(self._config_dict['level'])] if isinstance(
+          decoder_output, dict) else decoder_output

     for conv, norm in zip(self._convs, self._norms):
       x = conv(x)
...
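Downstream call sites change accordingly: the 2D head is now invoked with a single (backbone_output, decoder_output) tuple rather than two positional arguments. A hedged usage sketch follows; the import path and constructor defaults are assumptions based on the Model Garden layout around this revision, not part of the diff.

# Usage sketch (not from the commit): build a SegmentationHead and call it
# with the new single-tuple signature. The import path and constructor
# arguments are assumed from the surrounding test code and may differ
# across versions of the library.
import numpy as np
from official.vision.beta.modeling.heads import segmentation_heads

head = segmentation_heads.SegmentationHead(num_classes=10, level=3)

backbone_features = {'3': np.random.rand(2, 128, 128, 16),
                     '4': np.random.rand(2, 64, 64, 16)}
decoder_features = {'3': np.random.rand(2, 128, 128, 64),
                    '4': np.random.rand(2, 64, 64, 64)}

# Before this change: logits = head(backbone_features, decoder_features)
logits = head((backbone_features, decoder_features))
print(logits.shape)  # expected: (2, 128, 128, 10) with the defaults above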
@@ -58,7 +58,7 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
         decoder_max_level=decoder_max_level,
         num_decoder_filters=64)
-    logits = head(backbone_features, decoder_features)
+    logits = head((backbone_features, decoder_features))
     if level in decoder_features:
       self.assertAllEqual(logits.numpy().shape, [
...
@@ -62,7 +62,7 @@ class SegmentationModel(tf.keras.Model):
     else:
       decoder_features = backbone_features

-    return self.head(backbone_features, decoder_features)
+    return self.head((backbone_features, decoder_features))

   @property
   def checkpoint_items(
...
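SegmentationModel's public call signature is unchanged; only the internal head invocation now packs the two endpoint collections into a tuple, reusing the backbone endpoints when no decoder is configured. A minimal sketch of that forwarding pattern, with backbone, decoder, and head as assumed plain callables:

# Sketch of the forwarding logic above; `forward` and its arguments are
# illustrative stand-ins, not library APIs.
def forward(images, backbone, decoder, head):
  backbone_features = backbone(images)
  # Without a decoder, the backbone endpoints double as decoder endpoints.
  decoder_features = decoder(backbone_features) if decoder else backbone_features
  # The head receives both endpoint collections as a single tuple.
  return head((backbone_features, decoder_features))

# Toy check with stand-in callables.
out = forward('img', backbone=lambda x: {'3': x},
              decoder=None, head=lambda pair: pair[1]['3'])
assert out == 'img'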
@@ -171,7 +171,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
     decoder_features = model_outputs['decoder_features']

     segmentation_outputs = self.segmentation_head(
-        backbone_features, decoder_features, training=training)
+        (backbone_features, decoder_features), training=training)

     model_outputs.update({
         'segmentation_outputs': segmentation_outputs,
...