Commit 19620a5d authored by A. Unique TensorFlower

Merge pull request #10303 from srihari-humbarwadi:panoptic-fpn

PiperOrigin-RevId: 406308204
parents aa75f089 fa06f822
@@ -65,10 +65,14 @@ class SegmentationHead(hyperparams.Config):
   use_depthwise_convolution: bool = False
   prediction_kernel_size: int = 1
   upsample_factor: int = 1
-  feature_fusion: Optional[str] = None  # None, deeplabv3plus, or pyramid_fusion
+  feature_fusion: Optional[
+      str] = None  # None, deeplabv3plus, panoptic_fpn_fusion or pyramid_fusion
   # deeplabv3plus feature fusion params
   low_level: Union[int, str] = 2
   low_level_num_filters: int = 48
+  # panoptic_fpn_fusion params
+  decoder_min_level: Optional[Union[int, str]] = None
+  decoder_max_level: Optional[Union[int, str]] = None


 @dataclasses.dataclass
...
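For orientation, a minimal usage sketch (not part of the diff; it assumes this hunk is from official.vision.beta.configs.semantic_segmentation): selecting the new fusion mode requires the decoder level range, while the other modes ignore those fields.

# Hedged sketch: configure the head for the new panoptic_fpn_fusion mode.
from official.vision.beta.configs import semantic_segmentation

head = semantic_segmentation.SegmentationHead(
    feature_fusion='panoptic_fpn_fusion',
    decoder_min_level=2,
    decoder_max_level=6)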
@@ -32,11 +32,28 @@ import tensorflow as tf
 from official.vision.beta.evaluation import panoptic_quality


+def _crop_padding(mask, image_info):
+  """Crops padded masks to match original image shape.
+
+  Args:
+    mask: a padded mask tensor.
+    image_info: a tensor that holds information about original and preprocessed
+      images.
+
+  Returns:
+    cropped and padded masks: tf.Tensor
+  """
+  image_shape = tf.cast(image_info[0, :], tf.int32)
+  mask = tf.image.crop_to_bounding_box(
+      tf.expand_dims(mask, axis=-1), 0, 0,
+      image_shape[0], image_shape[1])
+  return tf.expand_dims(mask[:, :, 0], axis=0)
+
+
 class PanopticQualityEvaluator:
   """Panoptic Quality metric class."""

   def __init__(self, num_categories, ignored_label, max_instances_per_category,
-               offset, is_thing=None):
+               offset, is_thing=None, rescale_predictions=False):
     """Constructs Panoptic Quality evaluation class.

     The class provides the interface to Panoptic Quality metrics_fn.
@@ -55,10 +72,14 @@ class PanopticQualityEvaluator:
         `is_thing[category_id]` is True iff that category is a "thing" category
         instead of "stuff." Default to `None`, and it means categories are not
         classified into these two categories.
+      rescale_predictions: `bool`, whether to scale back prediction to original
+        image sizes. If True, groundtruths['image_info'] is used to rescale
+        predictions.
     """
     self._pq_metric_module = panoptic_quality.PanopticQuality(
         num_categories, ignored_label, max_instances_per_category, offset)
     self._is_thing = is_thing
+    self._rescale_predictions = rescale_predictions
     self._required_prediction_fields = ['category_mask', 'instance_mask']
     self._required_groundtruth_fields = ['category_mask', 'instance_mask']
     self.reset_states()
@@ -110,6 +131,13 @@ class PanopticQualityEvaluator:
       Required fields:
         - category_mask: a numpy array of uint16 of shape [batch_size, H, W].
         - instance_mask: a numpy array of uint16 of shape [batch_size, H, W].
+        - image_info: [batch, 4, 2], a tensor that holds information about
+          original and preprocessed images. Each entry is in the format of
+          [[original_height, original_width], [input_height, input_width],
+          [y_scale, x_scale], [y_offset, x_offset]], where
+          [input_height, input_width] is the actual scaled image size, and
+          [y_scale, x_scale] is the scaling factor, which is the ratio of
+          scaled dimension / original dimension.
       predictions: a dictionary of tensors including the fields below. See
         different parsers under `../dataloader` for more details.
         Required fields:
@@ -132,4 +160,25 @@ class PanopticQualityEvaluator:
         raise ValueError(
             'Missing the required key `{}` in groundtruths!'.format(k))

-    self._pq_metric_module.compare_and_accumulate(groundtruths, predictions)
+    if self._rescale_predictions:
+      for idx in range(len(groundtruths['category_mask'])):
+        image_info = groundtruths['image_info'][idx]
+        groundtruths_ = {
+            'category_mask':
+                _crop_padding(groundtruths['category_mask'][idx], image_info),
+            'instance_mask':
+                _crop_padding(groundtruths['instance_mask'][idx], image_info),
+        }
+        predictions_ = {
+            'category_mask':
+                _crop_padding(predictions['category_mask'][idx], image_info),
+            'instance_mask':
+                _crop_padding(predictions['instance_mask'][idx], image_info),
+        }
+        groundtruths_, predictions_ = self._convert_to_numpy(
+            groundtruths_, predictions_)
+        self._pq_metric_module.compare_and_accumulate(
+            groundtruths_, predictions_)
+    else:
+      self._pq_metric_module.compare_and_accumulate(groundtruths, predictions)
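A small runnable sketch (sizes assumed, not from the commit) of what `_crop_padding` does with `image_info`: row 0 holds the original image size, so a mask padded for batching is cropped back to that size before PQ matching.

import tensorflow as tf

mask = tf.ones([640, 640], dtype=tf.int32)  # category mask, padded to 640x640
image_info = tf.constant([[427., 640.],     # [original_height, original_width]
                          [640., 640.],     # [input_height, input_width]
                          [1., 1.],         # [y_scale, x_scale]
                          [0., 0.]])        # [y_offset, x_offset]
image_shape = tf.cast(image_info[0, :], tf.int32)
cropped = tf.image.crop_to_bounding_box(
    tf.expand_dims(mask, axis=-1), 0, 0, image_shape[0], image_shape[1])
print(cropped.shape)  # (427, 640, 1): padding rows below the image are dropped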
@@ -35,8 +35,11 @@ class SegmentationHead(tf.keras.layers.Layer):
                prediction_kernel_size: int = 1,
                upsample_factor: int = 1,
                feature_fusion: Optional[str] = None,
+               decoder_min_level: Optional[int] = None,
+               decoder_max_level: Optional[int] = None,
                low_level: int = 2,
                low_level_num_filters: int = 48,
+               num_decoder_filters: int = 256,
                activation: str = 'relu',
                use_sync_bn: bool = False,
                norm_momentum: float = 0.99,
@@ -60,15 +63,24 @@ class SegmentationHead(tf.keras.layers.Layer):
         prediction layer.
       upsample_factor: An `int` number to specify the upsampling factor to
         generate finer mask. Default 1 means no upsampling is applied.
-      feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, or None. If
-        `deeplabv3plus`, features from decoder_features[level] will be fused
-        with low level feature maps from backbone. If `pyramid_fusion`,
-        multiscale features will be resized and fused at the target level.
+      feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`,
+        `panoptic_fpn_fusion`, or None. If `deeplabv3plus`, features from
+        decoder_features[level] will be fused with low level feature maps from
+        backbone. If `pyramid_fusion`, multiscale features will be resized and
+        fused at the target level.
+      decoder_min_level: An `int` of minimum level from decoder to use in
+        feature fusion. It is only used when feature_fusion is set to
+        `panoptic_fpn_fusion`.
+      decoder_max_level: An `int` of maximum level from decoder to use in
+        feature fusion. It is only used when feature_fusion is set to
+        `panoptic_fpn_fusion`.
       low_level: An `int` of backbone level to be used for feature fusion. It
         is used when feature_fusion is set to `deeplabv3plus`.
       low_level_num_filters: An `int` of reduced number of filters for the low
         level features before fusing it with higher level features. It is only
         used when feature_fusion is set to `deeplabv3plus`.
+      num_decoder_filters: An `int` of number of filters in the decoder
+        outputs. It is only used when feature_fusion is set to
+        `panoptic_fpn_fusion`.
       activation: A `str` that indicates which activation is used, e.g. 'relu',
         'swish', etc.
       use_sync_bn: A `bool` that indicates whether to use synchronized batch
@@ -91,14 +103,17 @@ class SegmentationHead(tf.keras.layers.Layer):
         'prediction_kernel_size': prediction_kernel_size,
         'upsample_factor': upsample_factor,
         'feature_fusion': feature_fusion,
+        'decoder_min_level': decoder_min_level,
+        'decoder_max_level': decoder_max_level,
         'low_level': low_level,
         'low_level_num_filters': low_level_num_filters,
+        'num_decoder_filters': num_decoder_filters,
         'activation': activation,
         'use_sync_bn': use_sync_bn,
         'norm_momentum': norm_momentum,
         'norm_epsilon': norm_epsilon,
         'kernel_regularizer': kernel_regularizer,
-        'bias_regularizer': bias_regularizer,
+        'bias_regularizer': bias_regularizer
     }
     if tf.keras.backend.image_data_format() == 'channels_last':
       self._bn_axis = -1
@@ -141,6 +156,17 @@ class SegmentationHead(tf.keras.layers.Layer):
       self._dlv3p_norm = bn_op(
           name='segmentation_head_deeplabv3p_fusion_norm', **bn_kwargs)

+    elif self._config_dict['feature_fusion'] == 'panoptic_fpn_fusion':
+      self._panoptic_fpn_fusion = nn_layers.PanopticFPNFusion(
+          min_level=self._config_dict['decoder_min_level'],
+          max_level=self._config_dict['decoder_max_level'],
+          target_level=self._config_dict['level'],
+          num_filters=self._config_dict['num_filters'],
+          num_fpn_filters=self._config_dict['num_decoder_filters'],
+          activation=self._config_dict['activation'],
+          kernel_regularizer=self._config_dict['kernel_regularizer'],
+          bias_regularizer=self._config_dict['bias_regularizer'])
+
     # Segmentation head layers.
     self._convs = []
     self._norms = []
@@ -210,6 +236,8 @@ class SegmentationHead(tf.keras.layers.Layer):
     elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
       x = nn_layers.pyramid_feature_fusion(decoder_output,
                                            self._config_dict['level'])
+    elif self._config_dict['feature_fusion'] == 'panoptic_fpn_fusion':
+      x = self._panoptic_fpn_fusion(decoder_output)
     else:
       x = decoder_output[str(self._config_dict['level'])]
...
@@ -26,20 +26,38 @@ from official.vision.beta.modeling.heads import segmentation_heads
 class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):

   @parameterized.parameters(
-      (2, 'pyramid_fusion'),
-      (3, 'pyramid_fusion'),
-  )
-  def test_forward(self, level, feature_fusion):
-    head = segmentation_heads.SegmentationHead(
-        num_classes=10, level=level, feature_fusion=feature_fusion)
+      (2, 'pyramid_fusion', None, None),
+      (3, 'pyramid_fusion', None, None),
+      (2, 'panoptic_fpn_fusion', 2, 5),
+      (2, 'panoptic_fpn_fusion', 2, 6),
+      (3, 'panoptic_fpn_fusion', 3, 5),
+      (3, 'panoptic_fpn_fusion', 3, 6))
+  def test_forward(self, level, feature_fusion,
+                   decoder_min_level, decoder_max_level):
     backbone_features = {
         '3': np.random.rand(2, 128, 128, 16),
         '4': np.random.rand(2, 64, 64, 16),
+        '5': np.random.rand(2, 32, 32, 16),
     }
     decoder_features = {
-        '3': np.random.rand(2, 128, 128, 16),
-        '4': np.random.rand(2, 64, 64, 16),
+        '3': np.random.rand(2, 128, 128, 64),
+        '4': np.random.rand(2, 64, 64, 64),
+        '5': np.random.rand(2, 32, 32, 64),
+        '6': np.random.rand(2, 16, 16, 64),
     }
+
+    if feature_fusion == 'panoptic_fpn_fusion':
+      backbone_features['2'] = np.random.rand(2, 256, 256, 16)
+      decoder_features['2'] = np.random.rand(2, 256, 256, 64)
+
+    head = segmentation_heads.SegmentationHead(
+        num_classes=10,
+        level=level,
+        feature_fusion=feature_fusion,
+        decoder_min_level=decoder_min_level,
+        decoder_max_level=decoder_max_level,
+        num_decoder_filters=64)
+
     logits = head(backbone_features, decoder_features)

     if level in decoder_features:
...
@@ -13,12 +13,14 @@
 # limitations under the License.

 """Contains common building blocks for neural networks."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 from absl import logging
 import tensorflow as tf
+import tensorflow_addons as tfa

 from official.modeling import tf_utils
+from official.vision.beta.ops import spatial_transform_ops

 # Type annotations.
@@ -308,6 +310,110 @@ def pyramid_feature_fusion(inputs, target_level):
   return tf.math.add_n(resampled_feats)
+class PanopticFPNFusion(tf.keras.Model):
+  """Creates a Panoptic FPN feature Fusion layer.
+
+  This implements feature fusion for semantic segmentation head from the paper:
+  Alexander Kirillov, Ross Girshick, Kaiming He and Piotr Dollar.
+  Panoptic Feature Pyramid Networks.
+  (https://arxiv.org/pdf/1901.02446.pdf)
+  """
+
+  def __init__(
+      self,
+      min_level: int = 2,
+      max_level: int = 5,
+      target_level: int = 2,
+      num_filters: int = 128,
+      num_fpn_filters: int = 256,
+      activation: str = 'relu',
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      **kwargs):
+    """Initializes panoptic FPN feature fusion layer.
+
+    Args:
+      min_level: An `int` of minimum level to use in feature fusion.
+      max_level: An `int` of maximum level to use in feature fusion.
+      target_level: An `int` of the target feature level for feature fusion.
+      num_filters: An `int` number of filters in conv2d layers.
+      num_fpn_filters: An `int` number of filters in the FPN outputs.
+      activation: A `str` name of the activation function.
+      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
+        Conv2D. Default is None.
+      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
+        Conv2D.
+      **kwargs: Additional keyword arguments to be passed.
+    Returns:
+      A `float` `tf.Tensor` of shape [batch_size, feature_height,
+        feature_width, feature_channel].
+    """
+    if target_level > max_level:
+      raise ValueError('target_level should not be greater than max_level')
+    self._config_dict = {
+        'min_level': min_level,
+        'max_level': max_level,
+        'target_level': target_level,
+        'num_filters': num_filters,
+        'num_fpn_filters': num_fpn_filters,
+        'activation': activation,
+        'kernel_regularizer': kernel_regularizer,
+        'bias_regularizer': bias_regularizer,
+    }
+    norm = tfa.layers.GroupNormalization
+    conv2d = tf.keras.layers.Conv2D
+    activation_fn = tf_utils.get_activation(activation)
+    if tf.keras.backend.image_data_format() == 'channels_last':
+      norm_axis = -1
+    else:
+      norm_axis = 1
+
+    inputs = self._build_inputs(num_fpn_filters, min_level, max_level)
+
+    upscaled_features = []
+    for level in range(min_level, max_level + 1):
+      num_conv_layers = max(1, level - target_level)
+      x = inputs[str(level)]
+      for i in range(num_conv_layers):
+        x = conv2d(
+            filters=num_filters,
+            kernel_size=3,
+            padding='same',
+            kernel_initializer=tf.keras.initializers.VarianceScaling(),
+            kernel_regularizer=kernel_regularizer,
+            bias_regularizer=bias_regularizer)(x)
+        x = norm(groups=32, axis=norm_axis)(x)
+        x = activation_fn(x)
+        if level != target_level:
+          x = spatial_transform_ops.nearest_upsampling(x, scale=2)
+      upscaled_features.append(x)
+
+    fused_features = tf.math.add_n(upscaled_features)
+    self._output_specs = {str(target_level): fused_features.get_shape()}
+
+    super(PanopticFPNFusion, self).__init__(
+        inputs=inputs, outputs=fused_features, **kwargs)
+
+  def _build_inputs(self, num_filters: int,
+                    min_level: int, max_level: int):
+    inputs = {}
+    for level in range(min_level, max_level + 1):
+      inputs[str(level)] = tf.keras.Input(shape=[None, None, num_filters])
+    return inputs
+
+  def get_config(self) -> Mapping[str, Any]:
+    return self._config_dict
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    return cls(**config)
+
+  @property
+  def output_specs(self) -> Mapping[str, tf.TensorShape]:
+    """A dict of {level: TensorShape} pairs for the model output."""
+    return self._output_specs
+
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class Scale(tf.keras.layers.Layer):
   """Scales the input by a trainable scalar weight.
...
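A hedged usage sketch of the new layer (shapes and import path assumed; requires tensorflow_addons): FPN levels 2-5 of a 256x256 input are each convolved and repeatedly 2x-upsampled down to the target level, then summed.

import numpy as np
from official.vision.beta.modeling.layers import nn_layers

fusion = nn_layers.PanopticFPNFusion(
    min_level=2, max_level=5, target_level=2,
    num_filters=128, num_fpn_filters=256)
# Level l features have stride 2**l, i.e. spatial size 256 / 2**l.
features = {str(level): np.random.rand(1, 256 // 2**level, 256 // 2**level, 256)
            for level in range(2, 6)}
fused = fusion(features)
print(fused.shape)  # (1, 64, 64, 128): every level brought to level 2 and added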
@@ -30,7 +30,7 @@ from official.vision.beta.configs import semantic_segmentation
 SEGMENTATION_MODEL = semantic_segmentation.SemanticSegmentationModel
 SEGMENTATION_HEAD = semantic_segmentation.SegmentationHead

-_COCO_INPUT_PATH_BASE = 'coco'
+_COCO_INPUT_PATH_BASE = 'coco/tfrecords'
 _COCO_TRAIN_EXAMPLES = 118287
 _COCO_VAL_EXAMPLES = 5000

@@ -75,13 +75,17 @@ class DataConfig(maskrcnn.DataConfig):
 @dataclasses.dataclass
 class PanopticSegmentationGenerator(hyperparams.Config):
+  """Panoptic segmentation generator config."""
   output_size: List[int] = dataclasses.field(
       default_factory=list)
   mask_binarize_threshold: float = 0.5
-  score_threshold: float = 0.05
+  score_threshold: float = 0.5
+  things_overlap_threshold: float = 0.5
+  stuff_area_threshold: float = 4096.0
   things_class_label: int = 1
   void_class_label: int = 0
   void_instance_id: int = 0
+  rescale_predictions: bool = False


 @dataclasses.dataclass
@@ -106,7 +110,8 @@ class Losses(maskrcnn.Losses):
       default_factory=list)
   semantic_segmentation_use_groundtruth_dimension: bool = True
   semantic_segmentation_top_k_percent_pixels: float = 1.0
-  semantic_segmentation_weight: float = 1.0
+  instance_segmentation_weight: float = 1.0
+  semantic_segmentation_weight: float = 0.5


 @dataclasses.dataclass
@@ -114,10 +119,12 @@ class PanopticQualityEvaluator(hyperparams.Config):
   """Panoptic Quality Evaluator config."""
   num_categories: int = 2
   ignored_label: int = 0
-  max_instances_per_category: int = 100
+  max_instances_per_category: int = 256
   offset: int = 256 * 256 * 256
   is_thing: List[float] = dataclasses.field(
       default_factory=list)
+  rescale_predictions: bool = False
+  report_per_class_metrics: bool = False


 @dataclasses.dataclass
@@ -144,8 +151,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
   panoptic_quality_evaluator: PanopticQualityEvaluator = PanopticQualityEvaluator()  # pylint: disable=line-too-long


-@exp_factory.register_config_factory('panoptic_maskrcnn_resnetfpn_coco')
-def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
+@exp_factory.register_config_factory('panoptic_fpn_coco')
+def panoptic_fpn_coco() -> cfg.ExperimentConfig:
   """COCO panoptic segmentation with Panoptic Mask R-CNN."""
   train_batch_size = 64
   eval_batch_size = 8
@@ -169,18 +176,25 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
     is_thing.append(True if idx <= num_thing_categories else False)

   config = cfg.ExperimentConfig(
-      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
+      runtime=cfg.RuntimeConfig(
+          mixed_precision_dtype='float32', enable_xla=True),
       task=PanopticMaskRCNNTask(
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',  # pylint: disable=line-too-long
          init_checkpoint_modules=['backbone'],
          model=PanopticMaskRCNN(
              num_classes=91, input_size=[1024, 1024, 3],
              panoptic_segmentation_generator=PanopticSegmentationGenerator(
-                  output_size=[1024, 1024]),
+                  output_size=[640, 640], rescale_predictions=True),
              stuff_classes_offset=90,
              segmentation_model=SEGMENTATION_MODEL(
                  num_classes=num_semantic_segmentation_classes,
-                  head=SEGMENTATION_HEAD(level=3))),
+                  head=SEGMENTATION_HEAD(
+                      level=2,
+                      num_convs=0,
+                      num_filters=128,
+                      decoder_min_level=2,
+                      decoder_max_level=6,
+                      feature_fusion='panoptic_fpn_fusion'))),
          losses=Losses(l2_weight_decay=0.00004),
          train_data=DataConfig(
              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'train*'),
@@ -192,13 +206,19 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
              input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
+              parser=Parser(
+                  segmentation_resize_eval_groundtruth=False,
+                  segmentation_groundtruth_padded_size=[640, 640]),
              drop_remainder=False),
          annotation_file=os.path.join(_COCO_INPUT_PATH_BASE,
                                       'instances_val2017.json'),
+          segmentation_evaluation=semantic_segmentation.Evaluation(
+              report_per_class_iou=False, report_train_mean_iou=False),
          panoptic_quality_evaluator=PanopticQualityEvaluator(
              num_categories=num_panoptic_categories,
              ignored_label=0,
-              is_thing=is_thing)),
+              is_thing=is_thing,
+              rescale_predictions=True)),
       trainer=cfg.TrainerConfig(
          train_steps=22500,
          validation_steps=validation_steps,
...
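Because the factory was renamed, downstream launchers must switch to the new key. A hedged lookup sketch (module path assumed from the test import below; importing the configs module is what registers the factory):

from official.core import exp_factory
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn  # registers 'panoptic_fpn_coco'

config = exp_factory.get_exp_config('panoptic_fpn_coco')
print(config.task.model.segmentation_model.head.feature_fusion)
# -> 'panoptic_fpn_fusion'; the old 'panoptic_maskrcnn_resnetfpn_coco' key no longer resolves.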
@@ -25,7 +25,7 @@ from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_mas
 class PanopticMaskRCNNConfigTest(tf.test.TestCase, parameterized.TestCase):

   @parameterized.parameters(
-      ('panoptic_maskrcnn_resnetfpn_coco',),
+      ('panoptic_fpn_coco',),
   )
   def test_panoptic_maskrcnn_configs(self, config_name):
     config = exp_factory.get_exp_config(config_name)
...
@@ -228,8 +228,8 @@ class Parser(maskrcnn_input.Parser):
     image = image_mask[:, :, :-1]

     data['image'] = image
-    data['boxes'] = boxes
-    data['masks'] = masks
+    data['groundtruth_boxes'] = boxes
+    data['groundtruth_instance_masks'] = masks

     image, labels = super(Parser, self)._parse_train_data(data)
@@ -334,7 +334,9 @@ class Parser(maskrcnn_input.Parser):
     panoptic_instance_mask = panoptic_instance_mask[:, :, 0]

     labels['groundtruths'].update({
-        'gt_panoptic_category_mask': panoptic_category_mask,
-        'gt_panoptic_instance_mask': panoptic_instance_mask})
+        'gt_panoptic_category_mask':
+            tf.cast(panoptic_category_mask, dtype=tf.int32),
+        'gt_panoptic_instance_mask':
+            tf.cast(panoptic_instance_mask, dtype=tf.int32)})

     return image, labels
@@ -69,8 +69,10 @@ def build_panoptic_maskrcnn(
         input_specs=segmentation_decoder_input_specs,
         model_config=segmentation_config,
         l2_regularizer=l2_regularizer)
+    decoder_config = segmentation_decoder.get_config()
   else:
     segmentation_decoder = None
+    decoder_config = maskrcnn_model.decoder.get_config()

   segmentation_head_config = segmentation_config.head
   detection_head_config = model_config.detection_head
@@ -84,12 +86,15 @@ def build_panoptic_maskrcnn(
       num_filters=segmentation_head_config.num_filters,
       upsample_factor=segmentation_head_config.upsample_factor,
       feature_fusion=segmentation_head_config.feature_fusion,
+      decoder_min_level=segmentation_head_config.decoder_min_level,
+      decoder_max_level=segmentation_head_config.decoder_max_level,
       low_level=segmentation_head_config.low_level,
       low_level_num_filters=segmentation_head_config.low_level_num_filters,
       activation=norm_activation_config.activation,
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
       norm_epsilon=norm_activation_config.norm_epsilon,
+      num_decoder_filters=decoder_config['num_filters'],
       kernel_regularizer=l2_regularizer)

   if model_config.generate_panoptic_masks:
@@ -101,9 +106,12 @@ def build_panoptic_maskrcnn(
         stuff_classes_offset=model_config.stuff_classes_offset,
         mask_binarize_threshold=mask_binarize_threshold,
         score_threshold=postprocessing_config.score_threshold,
+        things_overlap_threshold=postprocessing_config.things_overlap_threshold,
         things_class_label=postprocessing_config.things_class_label,
+        stuff_area_threshold=postprocessing_config.stuff_area_threshold,
         void_class_label=postprocessing_config.void_class_label,
-        void_instance_id=postprocessing_config.void_instance_id)
+        void_instance_id=postprocessing_config.void_instance_id,
+        rescale_predictions=postprocessing_config.rescale_predictions)
   else:
     panoptic_segmentation_generator_obj = None
...
@@ -27,19 +27,18 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory
 class PanopticMaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):

   @parameterized.parameters(
-      ('resnet', (640, 640), 'dilated_resnet', 'fpn'),
-      ('resnet', (640, 640), 'dilated_resnet', 'aspp'),
-      ('resnet', (640, 640), None, 'fpn'),
-      ('resnet', (640, 640), None, 'aspp'),
-      ('resnet', (640, 640), None, None),
-      ('resnet', (None, None), 'dilated_resnet', 'fpn'),
-      ('resnet', (None, None), 'dilated_resnet', 'aspp'),
-      ('resnet', (None, None), None, 'fpn'),
-      ('resnet', (None, None), None, 'aspp'),
-      ('resnet', (None, None), None, None)
-  )
+      ('resnet', (640, 640), 'dilated_resnet', 'fpn', 'panoptic_fpn_fusion'),
+      ('resnet', (640, 640), 'dilated_resnet', 'aspp', 'deeplabv3plus'),
+      ('resnet', (640, 640), None, 'fpn', 'panoptic_fpn_fusion'),
+      ('resnet', (640, 640), None, 'aspp', 'deeplabv3plus'),
+      ('resnet', (640, 640), None, None, 'panoptic_fpn_fusion'),
+      ('resnet', (None, None), 'dilated_resnet', 'fpn', 'panoptic_fpn_fusion'),
+      ('resnet', (None, None), 'dilated_resnet', 'aspp', 'deeplabv3plus'),
+      ('resnet', (None, None), None, 'fpn', 'panoptic_fpn_fusion'),
+      ('resnet', (None, None), None, 'aspp', 'deeplabv3plus'),
+      ('resnet', (None, None), None, None, 'deeplabv3plus'))
   def test_builder(self, backbone_type, input_size, segmentation_backbone_type,
-                   segmentation_decoder_type):
+                   segmentation_decoder_type, fusion_type):
     num_classes = 2
     input_specs = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])
...
@@ -14,10 +14,12 @@
 """Contains definition for postprocessing layer to generate panoptic segmentations."""

-from typing import List
+from typing import List, Optional

 import tensorflow as tf

+from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import paste_masks
+

 class PanopticSegmentationGenerator(tf.keras.layers.Layer):
   """Panoptic segmentation generator layer."""
@@ -28,10 +30,13 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
                max_num_detections: int,
                stuff_classes_offset: int,
                mask_binarize_threshold: float = 0.5,
-               score_threshold: float = 0.05,
+               score_threshold: float = 0.5,
+               things_overlap_threshold: float = 0.5,
+               stuff_area_threshold: float = 4096,
                things_class_label: int = 1,
                void_class_label: int = 0,
                void_instance_id: int = -1,
+               rescale_predictions: bool = False,
                **kwargs):
     """Generates panoptic segmentation masks.
@@ -45,6 +50,10 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
       mask_binarize_threshold: A `float` threshold used to binarize instance
         masks.
       score_threshold: A `float` representing the threshold for deciding
         when to remove objects based on score.
+      things_overlap_threshold: A `float` representing a threshold for deciding
+        to ignore a thing if its overlap is above the threshold.
+      stuff_area_threshold: A `float` representing a threshold for deciding to
+        ignore a stuff class if its area is below the threshold.
       things_class_label: An `int` that represents a single merged category of
         all thing classes in the semantic segmentation output.
       void_class_label: An `int` that is used to represent empty or unlabelled
@@ -52,6 +61,8 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
       void_instance_id: An `int` that is used to denote regions that are not
         assigned to any thing class. That is, void_instance_id are assigned to
         both stuff regions and empty regions.
+      rescale_predictions: `bool`, whether to scale back prediction to original
+        image sizes. If True, image_info is used to rescale predictions.
       **kwargs: additional keyword arguments.
     """
     self._output_size = output_size
@@ -59,9 +70,12 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
     self._stuff_classes_offset = stuff_classes_offset
     self._mask_binarize_threshold = mask_binarize_threshold
     self._score_threshold = score_threshold
+    self._things_overlap_threshold = things_overlap_threshold
+    self._stuff_area_threshold = stuff_area_threshold
     self._things_class_label = things_class_label
     self._void_class_label = void_class_label
     self._void_instance_id = void_instance_id
+    self._rescale_predictions = rescale_predictions

     self._config_dict = {
         'output_size': output_size,
@@ -71,36 +85,15 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
         'score_threshold': score_threshold,
         'things_class_label': things_class_label,
         'void_class_label': void_class_label,
-        'void_instance_id': void_instance_id
+        'void_instance_id': void_instance_id,
+        'rescale_predictions': rescale_predictions
     }
     super(PanopticSegmentationGenerator, self).__init__(**kwargs)

-  def _paste_mask(self, box, mask):
-    pasted_mask = tf.ones(
-        self._output_size + [1], dtype=mask.dtype) * self._void_class_label
-
-    ymin = box[0]
-    xmin = box[1]
-    ymax = tf.clip_by_value(box[2] + 1, 0, self._output_size[0])
-    xmax = tf.clip_by_value(box[3] + 1, 0, self._output_size[1])
-
-    box_height = ymax - ymin
-    box_width = xmax - xmin
-
-    # resize mask to match the shape of the instance bounding box
-    resized_mask = tf.image.resize(
-        mask,
-        size=(box_height, box_width),
-        method='nearest')
-
-    # paste resized mask on a blank mask that matches image shape
-    pasted_mask = tf.raw_ops.TensorStridedSliceUpdate(
-        input=pasted_mask,
-        begin=[ymin, xmin],
-        end=[ymax, xmax],
-        strides=[1, 1],
-        value=resized_mask)
-
-    return pasted_mask
+  def build(self, input_shape):
+    grid_sampler = paste_masks.BilinearGridSampler(align_corners=False)
+    self._paste_masks_fn = paste_masks.PasteMasks(
+        output_size=self._output_size, grid_sampler=grid_sampler)

   def _generate_panoptic_masks(self, boxes, scores, classes, detections_masks,
                                segmentation_mask):
@@ -132,6 +125,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
         - category_mask: A `tf.Tensor` for category masks.
         - instance_mask: A `tf.Tensor` for instance masks.
     """
+
     # Offset stuff class predictions
     segmentation_mask = tf.where(
         tf.logical_or(
@@ -161,6 +155,9 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
     loop_end_idx = tf.minimum(
         tf.cast(loop_end_idx, dtype=tf.int32),
         self._max_num_detections)
+    pasted_masks = self._paste_masks_fn((
+        detections_masks[:loop_end_idx],
+        boxes[:loop_end_idx]))

     # add things segmentation to panoptic masks
     for i in range(loop_end_idx):
@@ -168,9 +165,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
       # the overlaps are resolved based on confidence score
       instance_idx = sorted_indices[i]

-      pasted_mask = self._paste_mask(
-          box=boxes[instance_idx],
-          mask=detections_masks[instance_idx])
+      pasted_mask = pasted_masks[instance_idx]

       class_id = tf.cast(classes[instance_idx], dtype=tf.float32)
@@ -182,6 +177,19 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
       if not tf.reduce_sum(tf.cast(binary_mask, tf.float32)) > 0:
         continue

+      overlap = tf.logical_and(
+          binary_mask,
+          tf.not_equal(category_mask, self._void_class_label))
+      binary_mask_area = tf.reduce_sum(
+          tf.cast(binary_mask, dtype=tf.float32))
+      overlap_area = tf.reduce_sum(
+          tf.cast(overlap, dtype=tf.float32))
+
+      # skip instances that have a big enough overlap with instances that
+      # have higher scores
+      if overlap_area / binary_mask_area > self._things_overlap_threshold:
+        continue
+
       # fill empty regions in category_mask represented by
       # void_class_label with class_id of the instance.
       category_mask = tf.where(
@@ -198,18 +206,25 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
           tf.ones_like(instance_mask) *
           tf.cast(instance_idx + 1, tf.float32), instance_mask)

-    # add stuff segmentation labels to empty regions of category_mask.
-    # we ignore the pixels labelled as "things", since we get them from
-    # the instance masks.
-    # TODO(srihari, arashwan): Support filtering stuff classes based on area.
-    category_mask = tf.where(
-        tf.logical_and(
-            tf.equal(
-                category_mask, self._void_class_label),
-            tf.logical_and(
-                tf.not_equal(segmentation_mask, self._things_class_label),
-                tf.not_equal(segmentation_mask, self._void_class_label))),
-        segmentation_mask, category_mask)
+    stuff_class_ids = tf.unique(tf.reshape(segmentation_mask, [-1])).y
+    for stuff_class_id in stuff_class_ids:
+      if stuff_class_id == self._things_class_label:
+        continue
+
+      stuff_mask = tf.logical_and(
+          tf.equal(segmentation_mask, stuff_class_id),
+          tf.equal(category_mask, self._void_class_label))
+
+      stuff_mask_area = tf.reduce_sum(
+          tf.cast(stuff_mask, dtype=tf.float32))
+
+      if stuff_mask_area < self._stuff_area_threshold:
+        continue
+
+      category_mask = tf.where(
+          stuff_mask,
+          tf.ones_like(category_mask) * stuff_class_id,
+          category_mask)

     results = {
         'category_mask': category_mask[:, :, 0],
@@ -217,19 +232,64 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
     }
     return results
-  def call(self, inputs):
+  def _resize_and_pad_masks(self, mask, image_info):
+    """Resizes masks to match the original image shape and pads to `output_size`.
+
+    Args:
+      mask: a padded mask tensor.
+      image_info: a tensor that holds information about original and
+        preprocessed images.
+
+    Returns:
+      resized and padded masks: tf.Tensor.
+    """
+    rescale_size = tf.cast(
+        tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
+    image_shape = tf.cast(image_info[0, :], tf.int32)
+    offsets = tf.cast(image_info[3, :], tf.int32)
+
+    mask = tf.image.resize(
+        mask,
+        rescale_size,
+        method='bilinear')
+    mask = tf.image.crop_to_bounding_box(
+        mask,
+        offsets[0], offsets[1],
+        image_shape[0],
+        image_shape[1])
+    mask = tf.image.pad_to_bounding_box(
+        mask, 0, 0, self._output_size[0], self._output_size[1])
+    return mask
+
+  def call(self, inputs: tf.Tensor, image_info: Optional[tf.Tensor] = None):
     detections = inputs

     batched_scores = detections['detection_scores']
     batched_classes = detections['detection_classes']
-    batched_boxes = tf.cast(detections['detection_boxes'], dtype=tf.int32)
     batched_detections_masks = tf.expand_dims(
         detections['detection_masks'], axis=-1)
+    batched_boxes = detections['detection_boxes']
+
+    batched_segmentation_masks = tf.cast(
+        detections['segmentation_outputs'], dtype=tf.float32)
+
+    if self._rescale_predictions:
+      scale = tf.tile(
+          tf.cast(image_info[:, 2:3, :], dtype=batched_boxes.dtype),
+          multiples=[1, 1, 2])
+      batched_boxes /= scale
+
+      batched_segmentation_masks = tf.map_fn(
+          fn=lambda x: self._resize_and_pad_masks(x[0], x[1]),
+          elems=(
+              batched_segmentation_masks,
+              image_info),
+          fn_output_signature=tf.float32,
+          parallel_iterations=32)
+    else:
-    batched_segmentation_masks = tf.image.resize(
-        detections['segmentation_outputs'],
-        size=self._output_size,
-        method='bilinear')
+      batched_segmentation_masks = tf.image.resize(
+          batched_segmentation_masks,
+          size=self._output_size,
+          method='bilinear')

     batched_segmentation_masks = tf.expand_dims(tf.cast(
         tf.argmax(batched_segmentation_masks, axis=-1),
         dtype=tf.float32), axis=-1)
@@ -246,7 +306,7 @@ class PanopticSegmentationGenerator(tf.keras.layers.Layer):
         fn_output_signature={
             'category_mask': tf.float32,
             'instance_mask': tf.float32
-        })
+        }, parallel_iterations=32)

     for k, v in panoptic_masks.items():
       panoptic_masks[k] = tf.cast(v, dtype=tf.int32)
...
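A worked numeric sketch (sizes assumed, not from the commit) of the `_resize_and_pad_masks` geometry: with image_info rows [original, input, scale, offset], `input / scale` recovers the pre-padding resolution at original scale, after which the mask is cropped at the offsets and padded back out to `output_size`.

import tensorflow as tf

image_info = tf.constant([[480., 640.],    # original image size
                          [1024., 1024.],  # padded model input size
                          [1.6, 1.6],      # scale = scaled / original (1024 / 640)
                          [0., 0.]])       # crop offsets
rescale_size = tf.cast(
    tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
print(rescale_size.numpy())  # [640 640]: resize target; the mask is then
                             # cropped to 480x640 and padded to output_size.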
...@@ -37,7 +37,8 @@ class PanopticSegmentationGeneratorTest( ...@@ -37,7 +37,8 @@ class PanopticSegmentationGeneratorTest(
'score_threshold': 0.005, 'score_threshold': 0.005,
'things_class_label': 1, 'things_class_label': 1,
'void_class_label': 0, 'void_class_label': 0,
'void_instance_id': -1 'void_instance_id': -1,
'rescale_predictions': False,
} }
generator = PANOPTIC_SEGMENTATION_GENERATOR(**config) generator = PANOPTIC_SEGMENTATION_GENERATOR(**config)
...@@ -79,7 +80,8 @@ class PanopticSegmentationGeneratorTest( ...@@ -79,7 +80,8 @@ class PanopticSegmentationGeneratorTest(
'score_threshold': 0.005, 'score_threshold': 0.005,
'things_class_label': 1, 'things_class_label': 1,
'void_class_label': 0, 'void_class_label': 0,
'void_instance_id': -1 'void_instance_id': -1,
'rescale_predictions': False,
} }
generator = PANOPTIC_SEGMENTATION_GENERATOR(**config) generator = PANOPTIC_SEGMENTATION_GENERATOR(**config)
......
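The commit also adds the module imported above as official.vision.beta.projects.panoptic_maskrcnn.modeling.layers.paste_masks; the new file is reproduced in full below.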
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definition for bilinear grid sampling and mask pasting layers."""
from typing import List
import tensorflow as tf
class BilinearGridSampler(tf.keras.layers.Layer):
  """Bilinear Grid Sampling layer."""

  def __init__(self, align_corners: bool = False, **kwargs):
    """Initializes the bilinear grid sampling layer.

    Args:
      align_corners: A `bool`. If True, the centers of the 4 corner pixels
        of the input and output tensors are aligned, preserving the values at
        the corner pixels.
      **kwargs: Additional kwargs arguments.
    """
    super(BilinearGridSampler, self).__init__(**kwargs)
    self.align_corners = align_corners
    self._config = {
        'align_corners': align_corners
    }

  def build(self, input_shape):
    features_shape, _, _ = input_shape
    _, height, width, channels = features_shape.as_list()

    self._height = height
    self._width = width
    self._channels = channels

  def _valid_coordinates(self, x_coord, y_coord):
    return tf.logical_and(
        tf.logical_and(
            tf.greater_equal(x_coord, 0),
            tf.greater_equal(y_coord, 0)),
        tf.logical_and(
            tf.less(x_coord, self._width),
            tf.less(y_coord, self._height)))

  def _get_pixel(self, features, x_coord, y_coord):
    x_coord = tf.cast(x_coord, dtype=tf.int32)
    y_coord = tf.cast(y_coord, dtype=tf.int32)
    clipped_x = tf.clip_by_value(x_coord, 0, self._width - 1)
    clipped_y = tf.clip_by_value(y_coord, 0, self._height - 1)
    batch_size, _, _, _ = features.shape.as_list()
    if batch_size is None:
      batch_size = tf.shape(features)[0]

    batch_indices = tf.reshape(
        tf.range(batch_size, dtype=tf.int32),
        shape=[batch_size, 1, 1])
    batch_indices = tf.tile(
        batch_indices,
        multiples=[1, x_coord.shape[1], x_coord.shape[2]])

    indices = tf.cast(
        tf.stack([batch_indices, clipped_y, clipped_x], axis=-1),
        dtype=tf.int32)
    gathered_pixels = tf.gather_nd(features, indices)

    return tf.where(
        tf.expand_dims(self._valid_coordinates(x_coord, y_coord), axis=-1),
        gathered_pixels,
        tf.zeros_like(gathered_pixels))

  def call(self, inputs):
    features, x_coord, y_coord = inputs

    x_coord += 1
    y_coord += 1
    if self.align_corners:
      x_coord = (x_coord * 0.5) * (self._width - 1)
      y_coord = (y_coord * 0.5) * (self._height - 1)
    else:
      x_coord = (x_coord * self._width - 1) * 0.5
      y_coord = (y_coord * self._height - 1) * 0.5

    left = tf.floor(x_coord)
    top = tf.floor(y_coord)
    right = left + 1
    bottom = top + 1

    top_left = (right - x_coord) * (bottom - y_coord)
    top_right = (x_coord - left) * (bottom - y_coord)
    bottom_left = (right - x_coord) * (y_coord - top)
    bottom_right = (x_coord - left) * (y_coord - top)

    i_top_left = self._get_pixel(features, left, top)
    i_top_right = self._get_pixel(features, right, top)
    i_bottom_left = self._get_pixel(features, left, bottom)
    i_bottom_right = self._get_pixel(features, right, bottom)

    i_top_left *= tf.expand_dims(top_left, axis=-1)
    i_top_right *= tf.expand_dims(top_right, axis=-1)
    i_bottom_left *= tf.expand_dims(bottom_left, axis=-1)
    i_bottom_right *= tf.expand_dims(bottom_right, axis=-1)

    interpolated_features = tf.math.add_n(
        [i_top_left, i_top_right, i_bottom_left, i_bottom_right])
    return interpolated_features

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config):
    return cls(**config)
class PasteMasks(tf.keras.layers.Layer):
  """Layer to paste instance masks."""

  def __init__(self, output_size: List[int],
               grid_sampler, **kwargs):
    """Resizes and pastes instance masks to match image size.

    Args:
      output_size: A `List` of integers that represent the height and width of
        the output mask.
      grid_sampler: A grid sampling layer. Currently only `BilinearGridSampler`
        is supported.
      **kwargs: Additional kwargs arguments.
    """
    super(PasteMasks, self).__init__(**kwargs)
    self._output_size = output_size
    self._grid_sampler = grid_sampler

    self._config = {
        'output_size': output_size,
        'grid_sampler': grid_sampler
    }

  def build(self, input_shape):
    self._x_coords = tf.range(0, self._output_size[1], dtype=tf.float32)
    self._y_coords = tf.range(0, self._output_size[0], dtype=tf.float32)

  def call(self, inputs):
    masks, boxes = inputs
    y0, x0, y1, x1 = tf.split(boxes, 4, axis=1)

    x_coords = tf.cast(self._x_coords, dtype=boxes.dtype)
    y_coords = tf.cast(self._y_coords, dtype=boxes.dtype)

    x_coords = (x_coords - x0) / (x1 - x0) * 2 - 1
    y_coords = (y_coords - y0) / (y1 - y0) * 2 - 1

    x_coords = tf.tile(
        tf.expand_dims(x_coords, axis=1),
        multiples=[1, self._output_size[0], 1])
    y_coords = tf.tile(
        tf.expand_dims(y_coords, axis=2),
        multiples=[1, 1, self._output_size[1]])
    pasted_masks = self._grid_sampler((masks, x_coords, y_coords))
    return pasted_masks

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config):
    return cls(**config)
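A hedged numeric sketch (box values assumed) of the PasteMasks coordinate transform: output pixels inside the box map linearly onto [-1, 1] in mask space, and the grid sampler zeroes everything that falls outside that range.

import tensorflow as tf

x0, x1 = 2.0, 6.0                       # box left/right edges in output pixels
x_coords = tf.range(0, 8, dtype=tf.float32)
normalized = (x_coords - x0) / (x1 - x0) * 2 - 1
print(normalized.numpy())
# [-2. -1.5 -1. -0.5  0.  0.5  1.  1.5]: only pixels 2..6 land in [-1, 1],
# so the instance mask is pasted there and the rest stays empty.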
@@ -143,12 +143,13 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
   def call(self,
            images: tf.Tensor,
-           image_shape: tf.Tensor,
+           image_info: tf.Tensor,
            anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
            gt_boxes: Optional[tf.Tensor] = None,
            gt_classes: Optional[tf.Tensor] = None,
            gt_masks: Optional[tf.Tensor] = None,
            training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
+    image_shape = image_info[:, 1, :]
     model_outputs = super(PanopticMaskRCNNModel, self).call(
         images=images,
         image_shape=image_shape,
@@ -177,7 +178,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
         })

     if not training and self.panoptic_segmentation_generator is not None:
-      panoptic_outputs = self.panoptic_segmentation_generator(model_outputs)
+      panoptic_outputs = self.panoptic_segmentation_generator(
+          model_outputs, image_info=image_info)
       model_outputs.update({'panoptic_outputs': panoptic_outputs})

     return model_outputs
...
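A one-line sketch (values assumed) of the new call contract: image_info is [batch, 4, 2], and row 1 recovers the padded input size that the old image_shape argument carried.

import tensorflow as tf

image_info = tf.constant([[[427., 640.], [640., 640.], [1., 1.], [0., 0.]]])
image_shape = image_info[:, 1, :]
print(image_shape.numpy())  # [[640. 640.]] -- forwarded to the Mask R-CNN call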
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
import os import os
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
...@@ -45,7 +44,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -45,7 +44,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
build_anchor_boxes=[True, False], build_anchor_boxes=[True, False],
shared_backbone=[True, False], shared_backbone=[True, False],
shared_decoder=[True, False], shared_decoder=[True, False],
is_training=[True, False])) is_training=[True,]))
def test_build_model(self, def test_build_model(self,
use_separable_conv, use_separable_conv,
build_anchor_boxes, build_anchor_boxes,
...@@ -53,23 +52,24 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -53,23 +52,24 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
shared_decoder, shared_decoder,
is_training=True): is_training=True):
num_classes = 3 num_classes = 3
min_level = 3 min_level = 2
max_level = 7 max_level = 6
num_scales = 3 num_scales = 3
aspect_ratios = [1.0] aspect_ratios = [1.0]
anchor_size = 3 anchor_size = 3
resnet_model_id = 50 resnet_model_id = 50
segmentation_resnet_model_id = 50 segmentation_resnet_model_id = 50
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) aspp_decoder_level = 2
fpn_decoder_level = 3 fpn_decoder_level = 2
num_anchors_per_location = num_scales * len(aspect_ratios) num_anchors_per_location = num_scales * len(aspect_ratios)
image_size = 128 image_size = 128
images = np.random.rand(2, image_size, image_size, 3) images = tf.random.normal([2, image_size, image_size, 3])
image_shape = np.array([[image_size, image_size], [image_size, image_size]]) image_info = tf.convert_to_tensor(
[[[image_size, image_size], [image_size, image_size], [1, 1], [0, 0]],
[[image_size, image_size], [image_size, image_size], [1, 1], [0, 0]]])
shared_decoder = shared_decoder and shared_backbone shared_decoder = shared_decoder and shared_backbone
if build_anchor_boxes: if build_anchor_boxes or not is_training:
anchor_boxes = anchor.Anchor( anchor_boxes = anchor.Anchor(
min_level=min_level, min_level=min_level,
max_level=max_level, max_level=max_level,
...@@ -115,15 +115,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -115,15 +115,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things, num_classes=2, # stuff and common class for things,
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -148,17 +153,17 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -148,17 +153,17 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
aspect_ratios=aspect_ratios, aspect_ratios=aspect_ratios,
anchor_size=anchor_size) anchor_size=anchor_size)
gt_boxes = np.array( gt_boxes = tf.convert_to_tensor(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]], [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]], [[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32) dtype=tf.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32) gt_classes = tf.convert_to_tensor([[2, 1, -1], [1, -1, -1]], dtype=tf.int32)
gt_masks = np.ones((2, 3, 100, 100)) gt_masks = tf.ones((2, 3, 100, 100))
# Results will be checked in test_forward. # Results will be checked in test_forward.
_ = model( _ = model(
images, images,
image_shape, image_info,
anchor_boxes, anchor_boxes,
gt_boxes, gt_boxes,
gt_classes, gt_classes,
...@@ -179,23 +184,24 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -179,23 +184,24 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
shared_backbone, shared_decoder, shared_backbone, shared_decoder,
generate_panoptic_masks): generate_panoptic_masks):
num_classes = 3 num_classes = 3
min_level = 3 min_level = 2
max_level = 4 max_level = 6
num_scales = 3 num_scales = 3
aspect_ratios = [1.0] aspect_ratios = [1.0]
anchor_size = 3 anchor_size = 3
segmentation_resnet_model_id = 101 segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) aspp_decoder_level = 2
fpn_decoder_level = 3 fpn_decoder_level = 2
class_agnostic_bbox_pred = False class_agnostic_bbox_pred = False
cascade_class_ensemble = False cascade_class_ensemble = False
image_size = (256, 256) image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3) images = tf.random.normal([2, image_size[0], image_size[1], 3])
image_shape = np.array([[224, 100], [100, 224]]) image_info = tf.convert_to_tensor(
[[[224, 100], [224, 100], [1, 1], [0, 0]],
[[224, 100], [224, 100], [1, 1], [0, 0]]])
shared_decoder = shared_decoder and shared_backbone shared_decoder = shared_decoder and shared_backbone
with strategy.scope(): with strategy.scope():
...@@ -250,15 +256,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -250,15 +256,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # one stuff class plus a common class for things num_classes=2, # one stuff class plus a common class for things
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -285,16 +296,17 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -285,16 +296,17 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
aspect_ratios=aspect_ratios, aspect_ratios=aspect_ratios,
anchor_size=anchor_size) anchor_size=anchor_size)
gt_boxes = np.array( gt_boxes = tf.convert_to_tensor(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]], [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
[[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]], [[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
dtype=np.float32) dtype=tf.float32)
gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32) gt_classes = tf.convert_to_tensor(
gt_masks = np.ones((2, 3, 100, 100)) [[2, 1, -1], [1, -1, -1]], dtype=tf.int32)
gt_masks = tf.ones((2, 3, 100, 100))
results = model( results = model(
images, images,
image_shape, image_info,
anchor_boxes, anchor_boxes,
gt_boxes, gt_boxes,
gt_classes, gt_classes,
...@@ -354,10 +366,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -354,10 +366,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
max_num_detections=100, max_num_detections=100,
stuff_classes_offset=90) stuff_classes_offset=90)
segmentation_resnet_model_id = 101 segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) min_level = 2
fpn_decoder_level = 3 max_level = 6
aspp_decoder_level = 2
fpn_decoder_level = 2
shared_decoder = shared_decoder and shared_backbone shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2) mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler( mask_sampler_obj = mask_sampler.MaskSampler(
...@@ -370,15 +383,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -370,15 +383,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # one stuff class plus a common class for things num_classes=2, # one stuff class plus a common class for things
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -397,8 +415,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -397,8 +415,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone=segmentation_backbone, segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder, segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head, segmentation_head=segmentation_head,
min_level=3, min_level=min_level,
max_level=7, max_level=max_level,
num_scales=3, num_scales=3,
aspect_ratios=[1.0], aspect_ratios=[1.0],
anchor_size=3) anchor_size=3)
...@@ -433,10 +451,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -433,10 +451,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
max_num_detections=100, max_num_detections=100,
stuff_classes_offset=90) stuff_classes_offset=90)
segmentation_resnet_model_id = 101 segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) min_level = 2
fpn_decoder_level = 3 max_level = 6
aspp_decoder_level = 2
fpn_decoder_level = 2
shared_decoder = shared_decoder and shared_backbone shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2) mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler( mask_sampler_obj = mask_sampler.MaskSampler(
...@@ -449,15 +468,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -449,15 +468,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # one stuff class plus a common class for things num_classes=2, # one stuff class plus a common class for things
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -476,8 +500,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -476,8 +500,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone=segmentation_backbone, segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder, segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head, segmentation_head=segmentation_head,
min_level=3, min_level=min_level,
max_level=7, max_level=max_level,
num_scales=3, num_scales=3,
aspect_ratios=[1.0], aspect_ratios=[1.0],
anchor_size=3) anchor_size=3)
......
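The shared-decoder branch of the tests above now exercises feature_fusion='panoptic_fpn_fusion' with decoder levels 2 through 6. A minimal sketch of the core idea, assuming the upsample-and-sum scheme from the Panoptic FPN paper (Kirillov et al., 2019); the library's actual fusion also applies convolutions at each level, and the function name below is illustrative rather than the library's API:

```python
import tensorflow as tf

def panoptic_fpn_fuse(decoder_feats, min_level=2, max_level=6):
  """Upsamples every FPN level to `min_level` resolution and sums them."""
  fused = decoder_feats[str(min_level)]
  target_size = tf.shape(fused)[1:3]
  for level in range(min_level + 1, max_level + 1):
    upsampled = tf.image.resize(
        decoder_feats[str(level)], target_size, method='bilinear')
    fused += upsampled
  return fused

# Features at strides 4..64 for a 128x128 input, as in test_build_model:
feats = {str(l): tf.random.normal([2, 128 // 2**l, 128 // 2**l, 256])
         for l in range(2, 7)}
print(panoptic_fpn_fuse(feats).shape)  # (2, 32, 32, 256)
```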
...@@ -47,8 +47,8 @@ from official.vision.beta.serving import export_saved_model_lib ...@@ -47,8 +47,8 @@ from official.vision.beta.serving import export_saved_model_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('experiment', 'panoptic_maskrcnn_resnetfpn_coco', flags.DEFINE_string('experiment', 'panoptic_fpn_coco',
'experiment type, e.g. panoptic_maskrcnn_resnetfpn_coco') 'experiment type, e.g. panoptic_fpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.') flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.') flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string( flags.DEFINE_multi_string(
......
...@@ -88,14 +88,12 @@ class PanopticSegmentationModule(detection.DetectionModule): ...@@ -88,14 +88,12 @@ class PanopticSegmentationModule(detection.DetectionModule):
image_info_spec), image_info_spec),
parallel_iterations=32)) parallel_iterations=32))
input_image_shape = image_info[:, 1, :]
# To work around a keras.Model limitation when saving a model whose layers # To work around a keras.Model limitation when saving a model whose layers
# take multiple inputs, we use `model.call` here to trigger the forward # take multiple inputs, we use `model.call` here to trigger the forward
# pass. Note that this disables some Keras magic that happens in `__call__`. # pass. Note that this disables some Keras magic that happens in `__call__`.
detections = self.model.call( detections = self.model.call(
images=images, images=images,
image_shape=input_image_shape, image_info=image_info,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
training=False) training=False)
......
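Several hunks in this change replace a bare image_shape tensor with the richer image_info tensor. As the test inputs above suggest, image_info packs four rows per image; a hedged sketch of that layout, reusing the identity preprocessing values from the tests:

```python
import tensorflow as tf

image_size = 128
image_info = tf.constant(
    [[[image_size, image_size],   # row 0: original height, width
      [image_size, image_size],   # row 1: desired height, width after padding
      [1.0, 1.0],                 # row 2: y, x scale applied while resizing
      [0.0, 0.0]]],               # row 3: y, x offset
    dtype=tf.float32)

# The removed serving line sliced out only the desired shape:
desired_shape = image_info[:, 1, :]
# Passing the full tensor also lets the model recover the original shape,
# e.g. to rescale predictions back at evaluation time:
original_shape = image_info[:, 0, :]
```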
...@@ -70,9 +70,9 @@ class PanopticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -70,9 +70,9 @@ class PanopticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
return [example for b in range(batch_size)] return [example for b in range(batch_size)]
@parameterized.parameters( @parameterized.parameters(
('image_tensor', 'panoptic_maskrcnn_resnetfpn_coco'), ('image_tensor', 'panoptic_fpn_coco'),
('image_bytes', 'panoptic_maskrcnn_resnetfpn_coco'), ('image_bytes', 'panoptic_fpn_coco'),
('tf_example', 'panoptic_maskrcnn_resnetfpn_coco'), ('tf_example', 'panoptic_fpn_coco'),
) )
def test_export(self, input_type, experiment_name): def test_export(self, input_type, experiment_name):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
...@@ -96,15 +96,14 @@ class PanopticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -96,15 +96,14 @@ class PanopticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
processed_images, anchor_boxes, image_info = module._build_inputs( processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((128, 128, 3), dtype=tf.uint8)) tf.zeros((128, 128, 3), dtype=tf.uint8))
image_shape = image_info[1, :] image_info = tf.expand_dims(image_info, 0)
image_shape = tf.expand_dims(image_shape, 0)
processed_images = tf.expand_dims(processed_images, 0) processed_images = tf.expand_dims(processed_images, 0)
for l, l_boxes in anchor_boxes.items(): for l, l_boxes in anchor_boxes.items():
anchor_boxes[l] = tf.expand_dims(l_boxes, 0) anchor_boxes[l] = tf.expand_dims(l_boxes, 0)
expected_outputs = module.model( expected_outputs = module.model(
images=processed_images, images=processed_images,
image_shape=image_shape, image_info=image_info,
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
training=False) training=False)
outputs = detection_fn(tf.constant(images)) outputs = detection_fn(tf.constant(images))
...@@ -113,7 +112,7 @@ class PanopticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -113,7 +112,7 @@ class PanopticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
expected_outputs['num_detections'].numpy()) expected_outputs['num_detections'].numpy())
def test_build_model_fail_with_none_batch_size(self): def test_build_model_fail_with_none_batch_size(self):
params = exp_factory.get_exp_config('panoptic_maskrcnn_resnetfpn_coco') params = exp_factory.get_exp_config('panoptic_fpn_coco')
input_specs = tf.keras.layers.InputSpec(shape=[1, 128, 128, 3]) input_specs = tf.keras.layers.InputSpec(shape=[1, 128, 128, 3])
model = factory.build_panoptic_maskrcnn( model = factory.build_panoptic_maskrcnn(
input_specs=input_specs, model_config=params.task.model) input_specs=input_specs, model_config=params.task.model)
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"""Panoptic MaskRCNN task definition.""" """Panoptic MaskRCNN task definition."""
from typing import Any, Dict, List, Mapping, Optional, Tuple from typing import Any, Dict, List, Mapping, Optional, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -178,6 +177,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -178,6 +177,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
ignore_label=params.semantic_segmentation_ignore_label, ignore_label=params.semantic_segmentation_ignore_label,
use_groundtruth_dimension=use_groundtruth_dimension, use_groundtruth_dimension=use_groundtruth_dimension,
top_k_percent_pixels=params.semantic_segmentation_top_k_percent_pixels) top_k_percent_pixels=params.semantic_segmentation_top_k_percent_pixels)
instance_segmentation_weight = params.instance_segmentation_weight
semantic_segmentation_weight = params.semantic_segmentation_weight semantic_segmentation_weight = params.semantic_segmentation_weight
losses = super(PanopticMaskRCNNTask, self).build_losses( losses = super(PanopticMaskRCNNTask, self).build_losses(
...@@ -190,7 +191,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -190,7 +191,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
labels['gt_segmentation_mask']) labels['gt_segmentation_mask'])
model_loss = ( model_loss = (
maskrcnn_loss + semantic_segmentation_weight * segmentation_loss) instance_segmentation_weight * maskrcnn_loss +
semantic_segmentation_weight * segmentation_loss)
total_loss = model_loss total_loss = model_loss
if aux_losses: if aux_losses:
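The rewritten model loss above is a plain weighted sum of the instance (Mask R-CNN) and semantic segmentation branch losses, with both weights now read from the task config. A self-contained numeric sketch; the weights and loss values here are illustrative only:

```python
import tensorflow as tf

instance_segmentation_weight = 1.0    # illustrative config values
semantic_segmentation_weight = 0.5

maskrcnn_loss = tf.constant(2.0)      # stand-in instance-branch loss
segmentation_loss = tf.constant(4.0)  # stand-in semantic-branch loss

model_loss = (instance_segmentation_weight * maskrcnn_loss +
              semantic_segmentation_weight * segmentation_loss)
print(model_loss.numpy())  # 1.0 * 2.0 + 0.5 * 4.0 = 4.0
```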
...@@ -240,12 +242,18 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -240,12 +242,18 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
rescale_predictions = (not self.task_config.validation_data.parser rescale_predictions = (not self.task_config.validation_data.parser
.segmentation_resize_eval_groundtruth) .segmentation_resize_eval_groundtruth)
self.segmentation_perclass_iou_metric = segmentation_metrics.PerClassIoU( self.segmentation_perclass_iou_metric = segmentation_metrics.PerClassIoU(
name='per_class_iou', name='per_class_iou',
num_classes=num_segmentation_classes, num_classes=num_segmentation_classes,
rescale_predictions=rescale_predictions, rescale_predictions=rescale_predictions,
dtype=tf.float32) dtype=tf.float32)
if isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy):
self._process_iou_metric_on_cpu = True
else:
self._process_iou_metric_on_cpu = False
if self.task_config.model.generate_panoptic_masks: if self.task_config.model.generate_panoptic_masks:
if not self.task_config.validation_data.parser.include_panoptic_masks: if not self.task_config.validation_data.parser.include_panoptic_masks:
raise ValueError('`include_panoptic_masks` should be set to True when' raise ValueError('`include_panoptic_masks` should be set to True when'
...@@ -256,7 +264,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -256,7 +264,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
ignored_label=pq_config.ignored_label, ignored_label=pq_config.ignored_label,
max_instances_per_category=pq_config.max_instances_per_category, max_instances_per_category=pq_config.max_instances_per_category,
offset=pq_config.offset, offset=pq_config.offset,
is_thing=pq_config.is_thing) is_thing=pq_config.is_thing,
rescale_predictions=pq_config.rescale_predictions)
return metrics return metrics
...@@ -282,7 +291,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -282,7 +291,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = model( outputs = model(
images, images,
image_shape=labels['image_info'][:, 1, :], image_info=labels['image_info'],
anchor_boxes=labels['anchor_boxes'], anchor_boxes=labels['anchor_boxes'],
gt_boxes=labels['gt_boxes'], gt_boxes=labels['gt_boxes'],
gt_classes=labels['gt_classes'], gt_classes=labels['gt_classes'],
...@@ -351,7 +360,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -351,7 +360,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
outputs = model( outputs = model(
images, images,
anchor_boxes=labels['anchor_boxes'], anchor_boxes=labels['anchor_boxes'],
image_shape=labels['image_info'][:, 1, :], image_info=labels['image_info'],
training=False) training=False)
logs = {self.loss: 0} logs = {self.loss: 0}
...@@ -369,18 +378,26 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -369,18 +378,26 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
'valid_masks': labels['groundtruths']['gt_segmentation_valid_mask'], 'valid_masks': labels['groundtruths']['gt_segmentation_valid_mask'],
'image_info': labels['image_info'] 'image_info': labels['image_info']
} }
if self._process_iou_metric_on_cpu:
logs.update({ logs.update({
self.coco_metric.name: (labels['groundtruths'], coco_model_outputs), self.coco_metric.name: (labels['groundtruths'], coco_model_outputs),
self.segmentation_perclass_iou_metric.name: ( self.segmentation_perclass_iou_metric.name: (
segmentation_labels, segmentation_labels,
outputs['segmentation_outputs']) outputs['segmentation_outputs'])
}) })
else:
self.segmentation_perclass_iou_metric.update_state(
segmentation_labels,
outputs['segmentation_outputs'])
if self.task_config.model.generate_panoptic_masks: if self.task_config.model.generate_panoptic_masks:
pq_metric_labels = { pq_metric_labels = {
'category_mask': 'category_mask':
labels['groundtruths']['gt_panoptic_category_mask'], labels['groundtruths']['gt_panoptic_category_mask'],
'instance_mask': 'instance_mask':
labels['groundtruths']['gt_panoptic_instance_mask'] labels['groundtruths']['gt_panoptic_instance_mask'],
'image_info': labels['image_info']
} }
logs.update({ logs.update({
self.panoptic_quality_metric.name: self.panoptic_quality_metric.name:
...@@ -398,6 +415,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -398,6 +415,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.coco_metric.update_state( self.coco_metric.update_state(
step_outputs[self.coco_metric.name][0], step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1]) step_outputs[self.coco_metric.name][1])
if self._process_iou_metric_on_cpu:
self.segmentation_perclass_iou_metric.update_state( self.segmentation_perclass_iou_metric.update_state(
step_outputs[self.segmentation_perclass_iou_metric.name][0], step_outputs[self.segmentation_perclass_iou_metric.name][0],
step_outputs[self.segmentation_perclass_iou_metric.name][1]) step_outputs[self.segmentation_perclass_iou_metric.name][1])
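Taken together, the validation-step and aggregate_logs hunks implement a common TPU pattern: a metric whose update op is awkward to run in the replicated step (here the per-class IoU confusion matrix) is fed its raw (labels, outputs) pair through the step logs and updated eagerly on the host instead. A minimal sketch of the pattern with a generic Keras metric; the names are illustrative, not the task's API:

```python
import tensorflow as tf

metric = tf.keras.metrics.MeanIoU(num_classes=2)
process_on_cpu = isinstance(
    tf.distribute.get_strategy(), tf.distribute.TPUStrategy)

def validation_step(labels, outputs):
  logs = {}
  if process_on_cpu:
    logs['mean_iou'] = (labels, outputs)  # defer: stash raw tensors
  else:
    metric.update_state(labels, outputs)  # update inside the graph
  return logs

def aggregate_logs(step_outputs):
  if process_on_cpu:
    labels, outputs = step_outputs['mean_iou']
    metric.update_state(labels, outputs)  # runs eagerly on the host CPU
```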
...@@ -424,7 +443,17 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -424,7 +443,17 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()}) result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()})
if self.task_config.model.generate_panoptic_masks: if self.task_config.model.generate_panoptic_masks:
for k, value in self.panoptic_quality_metric.result().items(): report_per_class_metrics = self.task_config.panoptic_quality_evaluator.report_per_class_metrics
result['panoptic_quality/' + k] = value panoptic_quality_results = self.panoptic_quality_metric.result()
for k, value in panoptic_quality_results.items():
if k.endswith('per_class'):
if report_per_class_metrics:
for i, per_class_value in enumerate(value):
metric_key = 'panoptic_quality/{}/class_{}'.format(k, i)
result[metric_key] = per_class_value
else:
continue
else:
result['panoptic_quality/{}'.format(k)] = value
return result return result
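The reporting logic above flattens the panoptic-quality result dict into scalar metrics: entries whose keys end in 'per_class' are expanded into one value per class when report_per_class_metrics is enabled, and skipped otherwise. An equivalent standalone sketch with made-up values:

```python
def flatten_pq_results(pq_results, report_per_class_metrics):
  result = {}
  for k, value in pq_results.items():
    if k.endswith('per_class'):
      if report_per_class_metrics:
        for i, per_class_value in enumerate(value):
          result['panoptic_quality/{}/class_{}'.format(k, i)] = (
              per_class_value)
    else:
      result['panoptic_quality/{}'.format(k)] = value
  return result

print(flatten_pq_results({'pq': 0.42, 'pq_per_class': [0.30, 0.54]},
                         report_per_class_metrics=True))
# {'panoptic_quality/pq': 0.42,
#  'panoptic_quality/pq_per_class/class_0': 0.3,
#  'panoptic_quality/pq_per_class/class_1': 0.54}
```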
...@@ -12,12 +12,8 @@ ...@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Panoptic MaskRCNN trainer. """Panoptic MaskRCNN trainer."""
All custom registries are imported from registry_imports. Here we use the
default trainer, so we directly call train.main. If you need to customize the
trainer, branch from `official/vision/beta/train.py` and make changes.
"""
from absl import app from absl import app
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
......