Commit 011422cf authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 408690416
parent 7f2c1d5a
......@@ -14,7 +14,7 @@
"""Segmentation heads."""
from typing import Any, Union, Sequence, Mapping
from typing import Any, Union, Sequence, Mapping, Tuple
import tensorflow as tf
from official.modeling import tf_utils
......@@ -139,25 +139,29 @@ class SegmentationHead3D(tf.keras.layers.Layer):
super(SegmentationHead3D, self).build(input_shape)
def call(self, backbone_output: Mapping[str, tf.Tensor],
decoder_output: Mapping[str, tf.Tensor]) -> tf.Tensor:
def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
"""Forward pass of the segmentation head.
Args:
backbone_output: a dict of tensors
- key: `str`, the level of the multilevel features.
- values: `Tensor`, the feature map tensors, whose shape is [batch,
height_l, width_l, channels].
decoder_output: a dict of tensors
- key: `str`, the level of the multilevel features.
- values: `Tensor`, the feature map tensors, whose shape is [batch,
height_l, width_l, channels].
It supports both a tuple of 2 tensors or 2 dictionaries. The first is
backbone endpoints, and the second is decoder endpoints. When inputs are
tensors, they are from a single level of feature maps. When inputs are
dictionaries, they contain multiple levels of feature maps, where the key
is the index of feature map.
Args:
inputs: A tuple of 2 feature map tensors of shape
[batch, height_l, width_l, channels] or 2 dictionaries of tensors:
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
Returns:
segmentation prediction mask: `Tensor`, the segmentation mask scores
predicted from input feature.
segmentation prediction mask: A `tf.Tensor` of the segmentation mask
scores predicted from input features.
"""
x = decoder_output[str(self._config_dict['level'])]
decoder_output = inputs[1]
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
for i, conv in enumerate(self._convs):
x = conv(x)
......
......@@ -42,7 +42,7 @@ class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
'1': np.random.rand(2, 128, 128, 128, 16),
'2': np.random.rand(2, 64, 64, 64, 16),
}
logits = head(backbone_features, decoder_features)
logits = head((backbone_features, decoder_features))
if str(level) in decoder_features:
self.assertAllEqual(logits.numpy().shape, [
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains definitions of segmentation heads."""
from typing import List, Union, Optional, Mapping
from typing import List, Union, Optional, Mapping, Tuple
import tensorflow as tf
from official.modeling import tf_utils
......@@ -204,16 +204,19 @@ class SegmentationHead(tf.keras.layers.Layer):
super(SegmentationHead, self).build(input_shape)
def call(self, backbone_output: Mapping[str, tf.Tensor],
decoder_output: Mapping[str, tf.Tensor]):
def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
"""Forward pass of the segmentation head.
It supports both a tuple of 2 tensors or 2 dictionaries. The first is
backbone endpoints, and the second is decoder endpoints. When inputs are
tensors, they are from a single level of feature maps. When inputs are
dictionaries, they contain multiple levels of feature maps, where the key
is the index of feature map.
Args:
backbone_output: A `dict` of tensors
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
decoder_output: A `dict` of tensors
inputs: A tuple of 2 feature map tensors of shape
[batch, height_l, width_l, channels] or 2 dictionaries of tensors:
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
......@@ -221,11 +224,14 @@ class SegmentationHead(tf.keras.layers.Layer):
segmentation prediction mask: A `tf.Tensor` of the segmentation mask
scores predicted from input features.
"""
backbone_output = inputs[0]
decoder_output = inputs[1]
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
# deeplabv3+ feature fusion
x = decoder_output[str(self._config_dict['level'])]
y = backbone_output[str(
self._config_dict['low_level'])]
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
y = backbone_output[str(self._config_dict['low_level'])] if isinstance(
backbone_output, dict) else backbone_output
y = self._dlv3p_norm(self._dlv3p_conv(y))
y = self._activation(y)
......@@ -234,12 +240,15 @@ class SegmentationHead(tf.keras.layers.Layer):
x = tf.cast(x, dtype=y.dtype)
x = tf.concat([x, y], axis=self._bn_axis)
elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
if not isinstance(decoder_output, dict):
raise ValueError('Only support dictionary decoder_output.')
x = nn_layers.pyramid_feature_fusion(decoder_output,
self._config_dict['level'])
elif self._config_dict['feature_fusion'] == 'panoptic_fpn_fusion':
x = self._panoptic_fpn_fusion(decoder_output)
else:
x = decoder_output[str(self._config_dict['level'])]
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
for conv, norm in zip(self._convs, self._norms):
x = conv(x)
......
......@@ -58,7 +58,7 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
decoder_max_level=decoder_max_level,
num_decoder_filters=64)
logits = head(backbone_features, decoder_features)
logits = head((backbone_features, decoder_features))
if level in decoder_features:
self.assertAllEqual(logits.numpy().shape, [
......
......@@ -62,7 +62,7 @@ class SegmentationModel(tf.keras.Model):
else:
decoder_features = backbone_features
return self.head(backbone_features, decoder_features)
return self.head((backbone_features, decoder_features))
@property
def checkpoint_items(
......
......@@ -171,7 +171,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
decoder_features = model_outputs['decoder_features']
segmentation_outputs = self.segmentation_head(
backbone_features, decoder_features, training=training)
(backbone_features, decoder_features), training=training)
model_outputs.update({
'segmentation_outputs': segmentation_outputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment