Unverified Commit c127d527 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents 78657911 457bcb85
......@@ -64,6 +64,72 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_mobile_conv2d_bn(self):
batch_norm_op = tf.keras.layers.BatchNormalization(
momentum=0.9,
epsilon=1.,
name='bn')
conv2d = movinet_layers.MobileConv2D(
filters=3,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
kernel_initializer='ones',
use_bias=False,
use_depthwise=False,
use_temporal=False,
use_buffered_input=True,
batch_norm_op=batch_norm_op,
)
inputs = tf.ones([1, 2, 2, 2, 3])
predicted = conv2d(inputs)
expected = tf.constant(
[[[[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]],
[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]]],
[[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]],
[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]]]]])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
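# A quick sanity check on the constant asserted above (illustrative): with an
# all-ones 3x3 kernel, an all-ones 2x2x3 input and 'same' padding, every output
# position sums a 2x2 spatial window over 3 channels, i.e. 12. A freshly built
# BatchNormalization layer (moving mean 0, moving variance 1, epsilon 1) then
# yields 12 / sqrt(1 + 1) ~= 8.48528 at inference time.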
def test_mobile_conv2d_activation(self):
conv2d = movinet_layers.MobileConv2D(
filters=3,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
kernel_initializer='ones',
use_bias=False,
use_depthwise=False,
use_temporal=False,
use_buffered_input=True,
activation_op=tf.nn.relu6,
)
inputs = tf.ones([1, 2, 2, 2, 3])
predicted = conv2d(inputs)
expected = tf.constant(
[[[[[6., 6., 6.],
[6., 6., 6.]],
[[6., 6., 6.],
[6., 6., 6.]]],
[[[6., 6., 6.],
[6., 6., 6.]],
[[6., 6., 6.],
[6., 6., 6.]]]]])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_mobile_conv2d_temporal(self):
conv2d = movinet_layers.MobileConv2D(
filters=3,
......@@ -378,6 +444,35 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_stream_movinet_block_none_se(self):
block = movinet_layers.MovinetBlock(
out_filters=3,
expand_filters=6,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
se_type='none',
state_prefix='test',
)
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, expected_states = block(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = block(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllEqual(list(expected_states.keys()), ['test_stream_buffer'])
def test_stream_classifier_head(self):
head = movinet_layers.Head(project_filters=5)
classifier_head = movinet_layers.ClassifierHead(
......
......@@ -99,6 +99,49 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_stream_nse(self):
"""Test if the backbone can be run in streaming mode w/o SE layer."""
tf.keras.backend.set_image_data_format('channels_last')
backbone = movinet.Movinet(
model_id='a0',
causal=True,
use_external_states=True,
se_type='none',
)
inputs = tf.ones([1, 5, 128, 128, 3])
init_states = backbone.init_states(tf.shape(inputs))
expected_endpoints, _ = backbone({**init_states, 'image': inputs})
frames = tf.split(inputs, inputs.shape[1], axis=1)
states = init_states
for frame in frames:
output, states = backbone({**states, 'image': frame})
predicted_endpoints = output
predicted = predicted_endpoints['head']
# The expected final output is simply the mean across frames
expected = expected_endpoints['head']
expected = tf.reduce_mean(expected, 1, keepdims=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
# Check contents in the states dictionary.
state_keys = list(init_states.keys())
self.assertIn('state_head_pool_buffer', state_keys)
self.assertIn('state_head_pool_frame_count', state_keys)
state_keys.remove('state_head_pool_buffer')
state_keys.remove('state_head_pool_frame_count')
# From here on, the remaining keys should only be 'stream_buffer' states for the convolutions.
for state_key in state_keys:
self.assertIn(
'stream_buffer', state_key,
msg=f'Expecting stream_buffer only, found {state_key}')
def test_movinet_2plus1d_stream(self):
tf.keras.backend.set_image_data_format('channels_last')
......
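The new streaming tests above all follow the same pattern: build the backbone with external states, initialize the states from the input shape, and feed the clip one frame at a time. A minimal sketch of that usage, assuming the module layout used elsewhere in this diff (`official.projects.movinet`):

```python
import tensorflow as tf

from official.projects.movinet.modeling import movinet  # path assumed from this diff

# Streaming MoViNet backbone without squeeze-excitation, as in the new test.
backbone = movinet.Movinet(
    model_id='a0',
    causal=True,
    use_external_states=True,
    se_type='none',
)

video = tf.ones([1, 5, 128, 128, 3])
states = backbone.init_states(tf.shape(video))

# Feed one frame at a time; the returned states carry the temporal context
# (stream buffers plus the head's pooling buffer and frame count).
for frame in tf.split(video, video.shape[1], axis=1):
  endpoints, states = backbone({**states, 'image': frame})

# After the last frame, endpoints['head'] should match the time-averaged
# 'head' output of running the full clip in one call (see the assertions above).
```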
......@@ -82,6 +82,9 @@ flags.DEFINE_string(
flags.DEFINE_string(
'activation', 'swish',
'The main activation to use across layers.')
flags.DEFINE_string(
'classifier_activation', 'swish',
'The classifier activation to use.')
flags.DEFINE_string(
'gating_activation', 'sigmoid',
'The gating activation to use in squeeze-excitation layers.')
......@@ -124,11 +127,15 @@ def main(_) -> None:
# states. These dimensions can be set to `None` once the model is built.
input_shape = [1 if s is None else s for s in input_specs.shape]
# Override swish activation implementation to remove custom gradients
activation = FLAGS.activation
if activation == 'swish':
# Override swish activation implementation to remove custom gradients
activation = 'simple_swish'
classifier_activation = FLAGS.classifier_activation
if classifier_activation == 'swish':
classifier_activation = 'simple_swish'
backbone = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
......@@ -145,9 +152,7 @@ def main(_) -> None:
num_classes=FLAGS.num_classes,
output_states=FLAGS.causal,
input_specs=dict(image=input_specs),
# TODO(dankondratyuk): currently set to swish, but will need to
# re-train to use other activations.
activation='simple_swish')
activation=classifier_activation)
model.build(input_shape)
# Compile model to generate some internal Keras variables.
......
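Both the backbone and the classifier head now go through the same substitution before the model is built: the default `swish` is replaced by `simple_swish`, whose implementation has no custom gradient and therefore exports cleanly. A compact, purely illustrative way to express that mapping (the helper name is hypothetical):

```python
def _export_friendly(activation_name: str) -> str:
  """Hypothetical helper: swap 'swish' for the export-safe 'simple_swish'."""
  return 'simple_swish' if activation_name == 'swish' else activation_name

activation = _export_friendly(FLAGS.activation)
classifier_activation = _export_friendly(FLAGS.classifier_activation)
```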
......@@ -18,7 +18,7 @@ from absl import flags
import tensorflow as tf
import tensorflow_hub as hub
from official.projects.movinet import export_saved_model
from official.projects.movinet.tools import export_saved_model
FLAGS = flags.FLAGS
......
......@@ -145,7 +145,7 @@ class Encoder(tf.keras.layers.Layer):
self._encoder_layers = []
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
# https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
for i in range(self._num_layers):
encoder_layer = nn_blocks.TransformerEncoderBlock(
inner_activation=activations.gelu,
......
......@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision package definition."""
# Lint as: python3
# pylint: disable=unused-import
from official.vision.beta import configs
from official.vision.beta import tasks
......@@ -55,6 +55,20 @@ depth, label smoothing and dropout.
| ResNet-RS-350 | 256x256 | 164.3 | 83.7 | 96.7 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-350-i256.tar.gz) |
| ResNet-RS-350 | 320x320 | 164.3 | 84.2 | 96.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs420_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-350-i320.tar.gz) |
#### Vision Transformer (ViT)
We support [ViT](https://arxiv.org/abs/2010.11929) and [DeiT](https://arxiv.org/abs/2012.12877) implementations in a TF Vision [project](https://github.com/tensorflow/models/tree/master/official/projects/vit). ViT models trained under the DeiT settings:
model | resolution | Top-1 | Top-5 |
--------- | :--------: | ----: | ----: |
ViT-s16 | 224x224 | 79.4 | 94.7 |
ViT-b16 | 224x224 | 81.8 | 95.8 |
ViT-l16 | 224x224 | 82.2 | 95.8 |
## Object Detection and Instance Segmentation
### Common Settings and Notes
......@@ -123,6 +137,7 @@ evaluated on [COCO](https://cocodataset.org/) val2017.
| Backbone | Resolution | Epochs | Params (M) | Box AP | Mask AP | Download
------------ | :--------: | -----: | ---------: | -----: | ------: | -------:
| SpineNet-49 | 640x640 | 500 | 56.4 | 46.4 | 40.0 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_cascadercnn_tpu.yaml)|
| SpineNet-96 | 1024x1024 | 500 | 70.8 | 50.9 | 43.8 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_cascadercnn_tpu.yaml)|
| SpineNet-143 | 1280x1280 | 500 | 94.9 | 51.9 | 45.0 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_cascadercnn_tpu.yaml)|
## Semantic Segmentation
......
# MobileNetV3Small ImageNet classification. 67.5% top-1 and 87.6% top-5 accuracy.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
num_classes: 1001
input_size: [224, 224, 3]
backbone:
type: 'mobilenet'
mobilenet:
model_id: 'MobileNetV3Small'
filter_size_scale: 1.0
norm_activation:
activation: 'relu'
norm_momentum: 0.997
norm_epsilon: 0.001
use_sync_bn: false
dropout_rate: 0.2
losses:
l2_weight_decay: 0.00001
one_hot: true
label_smoothing: 0.1
train_data:
input_path: 'imagenet-2012-tfrecord/train*'
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
global_batch_size: 4096
dtype: 'bfloat16'
drop_remainder: false
trainer:
train_steps: 312000 # 1000 epochs
validation_steps: 12
validation_interval: 312
steps_per_loop: 312 # NUM_EXAMPLES (1281167) // global_batch_size
summary_interval: 312
checkpoint_interval: 312
optimizer_config:
optimizer:
type: 'rmsprop'
rmsprop:
rho: 0.9
momentum: 0.9
epsilon: 0.002
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.01
decay_steps: 936 # 3 * steps_per_epoch
decay_rate: 0.99
staircase: true
ema:
average_decay: 0.9999
trainable_weights_only: false
warmup:
type: 'linear'
linear:
warmup_steps: 1560
warmup_learning_rate: 0.001
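All of the schedule constants in this config derive from a single steps-per-epoch figure; a short arithmetic check (plain Python, for illustration only):

```python
num_train_examples = 1281167        # ImageNet-2012 train split
num_eval_examples = 50000           # ImageNet-2012 validation split
global_batch_size = 4096

steps_per_epoch = num_train_examples // global_batch_size
assert steps_per_epoch == 312                         # steps_per_loop / summary / checkpoint interval
assert 1000 * steps_per_epoch == 312000               # train_steps: 1000 epochs
assert 3 * steps_per_epoch == 936                     # decay_steps: 3 epochs
assert 5 * steps_per_epoch == 1560                    # warmup_steps: 5 epochs
assert num_eval_examples // global_batch_size == 12   # validation_steps
```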
# --experiment_type=cascadercnn_spinenet_coco
# Expect to reach: box mAP: 51.9%, mask mAP: 45.0% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -8,12 +10,12 @@ task:
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
aug_scale_max: 2.5
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 3.0
anchor_size: 4.0
num_scales: 3
min_level: 3
max_level: 7
......
......@@ -714,7 +714,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
'use_depthwise': self._use_depthwise,
'use_residual': self._use_residual,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
'norm_epsilon': self._norm_epsilon,
'output_intermediate_endpoints': self._output_intermediate_endpoints
}
base_config = super(InvertedBottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......
......@@ -2284,8 +2284,9 @@ class MixupAndCutmix:
lambda x: _fill_rectangle(*x),
(images, random_center_width, random_center_height, cut_width // 2,
cut_height // 2, tf.reverse(images, [0])),
dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.float32),
fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32))
dtype=(
images.dtype, tf.int32, tf.int32, tf.int32, tf.int32, images.dtype),
fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=images.dtype))
return images, labels, lam
......@@ -2294,7 +2295,8 @@ class MixupAndCutmix:
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
labels.shape)
lam = tf.reshape(lam, [-1, 1, 1, 1])
images = lam * images + (1. - lam) * tf.reverse(images, [0])
lam_cast = tf.cast(lam, dtype=images.dtype)
images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0])
return images, labels, tf.squeeze(lam)
......
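The two changes above make the augmenter dtype-agnostic: the `map_fn` output signature now follows `images.dtype`, and the mixup weight `lam` is cast to that dtype before blending. A minimal sketch of the dtype-safe blend, factored out of the class for illustration:

```python
import tensorflow as tf

def mixup_blend(images: tf.Tensor, lam: tf.Tensor) -> tf.Tensor:
  """Blends each image with the batch-reversed batch using weight `lam`.

  `lam` is sampled in float32; casting it to `images.dtype` keeps the blend
  valid for float16 and bfloat16 inputs.
  """
  lam = tf.reshape(lam, [-1, 1, 1, 1])
  lam = tf.cast(lam, images.dtype)
  return lam * images + (1. - lam) * tf.reverse(images, [0])

images = tf.cast(tf.random.normal([4, 8, 8, 3]), tf.bfloat16)
blended = mixup_blend(images, tf.constant([0.3, 0.5, 0.7, 0.9]))
```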
......@@ -366,14 +366,19 @@ class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEqual(0, tf.reduce_max(aug_image))
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
def test_mixup_and_cutmix_smoothes_labels(self):
@parameterized.named_parameters([
('float16_images', tf.float16),
('bfloat16_images', tf.bfloat16),
('float32_images', tf.float32),
])
class MixupAndCutmixTest(parameterized.TestCase, tf.test.TestCase):
def test_mixup_and_cutmix_smoothes_labels(self, image_dtype):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
......@@ -388,12 +393,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
def test_mixup_changes_image(self):
def test_mixup_changes_image(self, image_dtype):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
......@@ -409,12 +414,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_image(self):
def test_cutmix_changes_image(self, image_dtype):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
......
......@@ -25,6 +25,7 @@ from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation
from official.vision.beta.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as deepmac_maskrcnn
SEGMENTATION_MODEL = semantic_segmentation.SemanticSegmentationModel
......@@ -89,7 +90,7 @@ class PanopticSegmentationGenerator(hyperparams.Config):
@dataclasses.dataclass
class PanopticMaskRCNN(maskrcnn.MaskRCNN):
class PanopticMaskRCNN(deepmac_maskrcnn.DeepMaskHeadRCNN):
"""Panoptic Mask R-CNN model config."""
segmentation_model: semantic_segmentation.SemanticSegmentationModel = (
SEGMENTATION_MODEL(num_classes=2))
......
......@@ -17,10 +17,10 @@
import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory as models_factory
from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.vision.beta.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
......@@ -50,7 +50,7 @@ def build_panoptic_maskrcnn(
segmentation_config = model_config.segmentation_model
# Builds the maskrcnn model.
maskrcnn_model = models_factory.build_maskrcnn(
maskrcnn_model = deep_mask_head_rcnn.build_maskrcnn(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
......@@ -120,6 +120,7 @@ def build_panoptic_maskrcnn(
# Combines maskrcnn, and segmentation models to build panoptic segmentation
# model.
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone=maskrcnn_model.backbone,
decoder=maskrcnn_model.decoder,
......
......@@ -18,10 +18,10 @@ from typing import List, Mapping, Optional, Union
import tensorflow as tf
from official.vision.beta.modeling import maskrcnn_model
from official.vision.beta.projects.deepmac_maskrcnn.modeling import maskrcnn_model
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
class PanopticMaskRCNNModel(maskrcnn_model.DeepMaskRCNNModel):
"""The Panoptic Segmentation model."""
def __init__(
......@@ -49,7 +49,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None, # pytype: disable=annotation-type-mismatch # typed-keras
anchor_size: Optional[float] = None,
use_gt_boxes_for_masks: bool = False, # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Initializes the Panoptic Mask R-CNN model.
......@@ -94,6 +95,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
anchor_size: A number representing the scale of size of the base anchor to
the feature stride 2^level.
use_gt_boxes_for_masks: `bool`, whether to use only gt boxes for masks.
**kwargs: keyword arguments to be passed.
"""
super(PanopticMaskRCNNModel, self).__init__(
......@@ -115,6 +117,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
use_gt_boxes_for_masks=use_gt_boxes_for_masks,
**kwargs)
self._config_dict.update({
......
......@@ -97,6 +97,20 @@ class PanopticSegmentationModule(detection.DetectionModule):
anchor_boxes=anchor_boxes,
training=False)
detections.pop('rpn_boxes')
detections.pop('rpn_scores')
detections.pop('cls_outputs')
detections.pop('box_outputs')
detections.pop('backbone_features')
detections.pop('decoder_features')
# Normalize detection boxes to [0, 1]. Here we first map them to the
# original image size, then normalize them to [0, 1].
detections['detection_boxes'] = (
detections['detection_boxes'] /
tf.tile(image_info[:, 2:3, :], [1, 1, 2]) /
tf.tile(image_info[:, 0:1, :], [1, 1, 2]))
if model_params.detection_generator.apply_nms:
final_outputs = {
'detection_boxes': detections['detection_boxes'],
......@@ -109,10 +123,15 @@ class PanopticSegmentationModule(detection.DetectionModule):
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores']
}
masks = detections['segmentation_outputs']
masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
classes = tf.math.argmax(masks, axis=-1)
scores = tf.nn.softmax(masks, axis=-1)
final_outputs.update({
'detection_masks': detections['detection_masks'],
'segmentation_outputs': detections['segmentation_outputs'],
'masks': masks,
'scores': scores,
'classes': classes,
'image_info': image_info
})
if model_params.generate_panoptic_masks:
......
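The added normalization divides the detection boxes by two rows of `image_info`. Assuming the usual TF Vision convention, where row 0 holds the original image size and row 2 the preprocessing scale, the boxes go from the resized-image frame back to the original frame and then to [0, 1]; the segmentation logits are likewise resized and converted to per-pixel classes and scores. A hedged sketch of the box step:

```python
import tensorflow as tf

def normalize_detection_boxes(boxes: tf.Tensor, image_info: tf.Tensor) -> tf.Tensor:
  """boxes: [batch, N, 4] in the resized-image frame; image_info: [batch, 4, 2].

  Assumes image_info row 0 = original (height, width) and row 2 = (y, x) scale,
  matching the convention used by the TF Vision detection exporters.
  """
  scale = tf.tile(image_info[:, 2:3, :], [1, 1, 2])          # [batch, 1, 4]
  original_size = tf.tile(image_info[:, 0:1, :], [1, 1, 2])  # [batch, 1, 4]
  # Undo the preprocessing scale first, then normalize by the original size.
  return boxes / scale / original_size
```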
......@@ -61,7 +61,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
def initialize(self, model: tf.keras.Model) -> None:
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint_modules:
if not self.task_config.init_checkpoint:
return
def _get_checkpoint_path(checkpoint_dir_or_file):
......
......@@ -34,7 +34,7 @@ import PIL.ImageFont as ImageFont
import six
import tensorflow as tf
from official.vision.beta.ops import box_ops
from official.vision.ops import box_ops
from official.vision.utils.object_detection import shape_utils
_TITLE_LEFT_MARGIN = 10
......
......@@ -1050,6 +1050,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
else:
raise ValueError(f'Unknown heatmap type - {self._box_heatmap_type}')
heatmap = tf.stop_gradient(heatmap)
heatmaps.append(heatmap)
# Return the stacked heatmaps over the batch.
......
......@@ -30,6 +30,7 @@ if tf_version.is_tf2():
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
......@@ -50,7 +51,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight',
'box_consistency_loss_normalize', 'box_consistency_tightness',
'color_consistency_warmup_steps', 'color_consistency_warmup_start'
'color_consistency_warmup_steps', 'color_consistency_warmup_start',
'use_only_last_stage'
])
......@@ -140,33 +142,24 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Unknown network type {}'.format(name))
def crop_masks_within_boxes(masks, boxes, output_size):
"""Crops masks to lie tightly within the boxes.
Args:
masks: A [num_instances, height, width] float tensor of masks.
boxes: A [num_instances, 4] sized tensor of normalized bounding boxes.
output_size: The height and width of the output masks.
Returns:
masks: A [num_instances, output_size, output_size] tensor of masks which
are cropped to be tightly within the given boxes and resized.
"""
masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[output_size, output_size])
return masks[:, 0, :, :, 0]
def _resize_instance_masks_non_empty(masks, shape):
"""Resize a non-empty tensor of masks to the given shape."""
height, width = shape
flattened_masks, batch_size, num_instances = flatten_first2_dims(masks)
flattened_masks = flattened_masks[:, :, :, tf.newaxis]
flattened_masks = tf.image.resize(
flattened_masks, (height, width),
method=tf.image.ResizeMethod.BILINEAR)
return unpack_first2_dims(
flattened_masks[:, :, :, 0], batch_size, num_instances)
def resize_instance_masks(masks, shape):
height, width = shape
masks_ex = masks[:, :, :, tf.newaxis]
masks_ex = tf.image.resize(masks_ex, (height, width),
method=tf.image.ResizeMethod.BILINEAR)
masks = masks_ex[:, :, :, 0]
return masks
batch_size, num_instances = tf.shape(masks)[0], tf.shape(masks)[1]
return tf.cond(
tf.shape(masks)[1] == 0,
lambda: tf.zeros((batch_size, num_instances, shape[0], shape[1])),
lambda: _resize_instance_masks_non_empty(masks, shape))
def filter_masked_classes(masked_class_ids, classes, weights, masks):
......@@ -175,94 +168,132 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
Args:
masked_class_ids: A list of class IDs allowed to have masks. These class IDs
are 1-indexed.
classes: A [num_instances, num_classes] float tensor containing the one-hot
encoded classes.
weights: A [num_instances] float tensor containing the weights of each
sample.
masks: A [num_instances, height, width] tensor containing the mask per
instance.
classes: A [batch_size, num_instances, num_classes] float tensor containing
the one-hot encoded classes.
weights: A [batch_size, num_instances] float tensor containing the weights
of each sample.
masks: A [batch_size, num_instances, height, width] tensor containing the
mask per instance.
Returns:
classes_filtered: A [num_instances, num_classes] float tensor containing the
one-hot encoded classes with classes not in masked_class_ids zeroed out.
weights_filtered: A [num_instances] float tensor containing the weights of
each sample with instances whose classes aren't in masked_class_ids
zeroed out.
masks_filtered: A [num_instances, height, width] tensor containing the mask
per instance with masks not belonging to masked_class_ids zeroed out.
classes_filtered: A [batch_size, num_instances, num_classes] float tensor
containing the one-hot encoded classes with classes not in
masked_class_ids zeroed out.
weights_filtered: A [batch_size, num_instances] float tensor containing the
weights of each sample with instances whose classes aren't in
masked_class_ids zeroed out.
masks_filtered: A [batch_size, num_instances, height, width] tensor
containing the mask per instance with masks not belonging to
masked_class_ids zeroed out.
"""
if len(masked_class_ids) == 0: # pylint:disable=g-explicit-length-test
return classes, weights, masks
if tf.shape(classes)[0] == 0:
if tf.shape(classes)[1] == 0:
return classes, weights, masks
masked_class_ids = tf.constant(np.array(masked_class_ids, dtype=np.int32))
label_id_offset = 1
masked_class_ids -= label_id_offset
class_ids = tf.argmax(classes, axis=1, output_type=tf.int32)
class_ids = tf.argmax(classes, axis=2, output_type=tf.int32)
matched_classes = tf.equal(
class_ids[:, tf.newaxis], masked_class_ids[tf.newaxis, :]
class_ids[:, :, tf.newaxis], masked_class_ids[tf.newaxis, tf.newaxis, :]
)
matched_classes = tf.reduce_any(matched_classes, axis=1)
matched_classes = tf.reduce_any(matched_classes, axis=2)
matched_classes = tf.cast(matched_classes, tf.float32)
return (
classes * matched_classes[:, tf.newaxis],
classes * matched_classes[:, :, tf.newaxis],
weights * matched_classes,
masks * matched_classes[:, tf.newaxis, tf.newaxis]
masks * matched_classes[:, :, tf.newaxis, tf.newaxis]
)
def crop_and_resize_feature_map(features, boxes, size):
"""Crop and resize regions from a single feature map given a set of boxes.
def flatten_first2_dims(tensor):
"""Flatten first 2 dimensions of a tensor.
Args:
features: A [H, W, C] float tensor.
boxes: A [N, 4] tensor of normalized boxes.
size: int, the size of the output features.
tensor: A tensor with shape [M, N, ....]
Returns:
per_box_features: A [N, size, size, C] tensor of cropped and resized
features.
flattened_tensor: A tensor of shape [M * N, ...]
M: int, the length of the first dimension of the input.
N: int, the length of the second dimension of the input.
"""
return spatial_transform_ops.matmul_crop_and_resize(
features[tf.newaxis], boxes[tf.newaxis], [size, size])[0]
shape = tf.shape(tensor)
d1, d2, rest = shape[0], shape[1], shape[2:]
tensor = tf.reshape(
tensor, tf.concat([[d1 * d2], rest], axis=0))
return tensor, d1, d2
def unpack_first2_dims(tensor, dim1, dim2):
"""Unpack the flattened first dimension of the tensor into 2 dimensions.
Args:
tensor: A tensor of shape [dim1 * dim2, ...]
dim1: int, the size of the first dimension.
dim2: int, the size of the second dimension.
Returns:
unflattened_tensor: A tensor of shape [dim1, dim2, ...].
"""
shape = tf.shape(tensor)
result_shape = tf.concat([[dim1, dim2], shape[1:]], axis=0)
return tf.reshape(tensor, result_shape)
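# Illustrative round-trip for the two helpers above: flattening merges the
# batch and instance axes so downstream ops can treat every instance as an
# independent sample, and unpacking restores the [batch, instance] layout.
#
#   x = tf.random.normal([2, 5, 32, 32])        # [batch, instances, H, W]
#   flat, d1, d2 = flatten_first2_dims(x)       # flat: [10, 32, 32]
#   restored = unpack_first2_dims(flat, d1, d2) # restored: [2, 5, 32, 32]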
def crop_and_resize_instance_masks(masks, boxes, mask_size):
"""Crop and resize each mask according to the given boxes.
Args:
masks: A [N, H, W] float tensor.
boxes: A [N, 4] float tensor of normalized boxes.
masks: A [B, N, H, W] float tensor.
boxes: A [B, N, 4] float tensor of normalized boxes.
mask_size: int, the size of the output masks.
Returns:
masks: A [N, mask_size, mask_size] float tensor of cropped and resized
masks: A [B, N, mask_size, mask_size] float tensor of cropped and resized
instance masks.
"""
masks, batch_size, num_instances = flatten_first2_dims(masks)
boxes, _, _ = flatten_first2_dims(boxes)
cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[mask_size, mask_size])
cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
return cropped_masks
return unpack_first2_dims(cropped_masks, batch_size, num_instances)
def fill_boxes(boxes, height, width):
"""Fills the area included in the box."""
blist = box_list.BoxList(boxes)
blist = box_list_ops.to_absolute_coordinates(blist, height, width)
boxes = blist.get()
"""Fills the area included in the boxes with 1s.
Args:
boxes: A [batch_size, num_instances, 4] shaped float tensor of boxes given
in the normalized coordinate space.
height: int, height of the output image.
width: int, width of the output image.
Returns:
filled_boxes: A [batch_size, num_instances, height, width] shaped float
tensor with 1s in the area that falls inside each box.
"""
ymin, xmin, ymax, xmax = tf.unstack(
boxes[:, tf.newaxis, tf.newaxis, :], 4, axis=3)
boxes[:, :, tf.newaxis, tf.newaxis, :], 4, axis=4)
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
ymin *= height
ymax *= height
xmin *= width
xmax *= width
ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32)
ygrid, xgrid = ygrid[tf.newaxis, :, :], xgrid[tf.newaxis, :, :]
ygrid, xgrid = (ygrid[tf.newaxis, tf.newaxis, :, :],
xgrid[tf.newaxis, tf.newaxis, :, :])
filled_boxes = tf.logical_and(
tf.logical_and(ygrid >= ymin, ygrid <= ymax),
......@@ -289,7 +320,7 @@ def embedding_projection(x, y):
return dot
def _get_2d_neighbors_kenel():
def _get_2d_neighbors_kernel():
"""Returns a conv. kernel that when applies generates 2D neighbors.
Returns:
......@@ -311,20 +342,34 @@ def generate_2d_neighbors(input_tensor, dilation=2):
following ops on TPU won't have to pad the last dimension to 128.
Args:
input_tensor: A float tensor of shape [height, width, channels].
input_tensor: A float tensor of shape [batch_size, height, width, channels].
dilation: int, the dilation factor for considering neighbors.
Returns:
output: A float tensor of all 8 2-D neighbors, of shape
[8, height, width, channels].
[8, batch_size, height, width, channels].
"""
input_tensor = tf.transpose(input_tensor, (2, 0, 1))
input_tensor = input_tensor[:, :, :, tf.newaxis]
kernel = _get_2d_neighbors_kenel()
# TODO(vighneshb) Minimize transposing here to save memory.
# input_tensor: [B, C, H, W]
input_tensor = tf.transpose(input_tensor, (0, 3, 1, 2))
# input_tensor: [B, C, H, W, 1]
input_tensor = input_tensor[:, :, :, :, tf.newaxis]
# input_tensor: [B * C, H, W, 1]
input_tensor, batch_size, channels = flatten_first2_dims(input_tensor)
kernel = _get_2d_neighbors_kernel()
# output: [B * C, H, W, 8]
output = tf.nn.atrous_conv2d(input_tensor, kernel, rate=dilation,
padding='SAME')
return tf.transpose(output, [3, 1, 2, 0])
# output: [B, C, H, W, 8]
output = unpack_first2_dims(output, batch_size, channels)
# return: [8, B, H, W, C]
return tf.transpose(output, [4, 0, 2, 3, 1])
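# Shape sketch for generate_2d_neighbors (illustrative):
#
#   feature_map = tf.random.normal([2, 16, 16, 3])            # [B, H, W, C]
#   neighbors = generate_2d_neighbors(feature_map, dilation=2)
#   # neighbors: [8, 2, 16, 16, 3] -- one slice per 2-D neighbor taken
#   # `dilation` pixels away, with batch and channel dimensions preserved.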
def gaussian_pixel_similarity(a, b, theta):
......@@ -339,12 +384,12 @@ def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
[1]: https://arxiv.org/abs/2012.02310
Args:
feature_map: A float tensor of shape [height, width, channels]
feature_map: A float tensor of shape [batch_size, height, width, channels]
dilation: int, the dilation factor.
theta: The denominator while taking difference inside the gaussian.
Returns:
dilated_similarity: A tensor of shape [8, height, width]
dilated_similarity: A tensor of shape [8, batch_size, height, width]
"""
neighbors = generate_2d_neighbors(feature_map, dilation)
feature_map = feature_map[tf.newaxis]
......@@ -358,21 +403,26 @@ def dilated_cross_same_mask_label(instance_masks, dilation=2):
[1]: https://arxiv.org/abs/2012.02310
Args:
instance_masks: A float tensor of shape [num_instances, height, width]
instance_masks: A float tensor of shape [batch_size, num_instances,
height, width]
dilation: int, the dilation factor.
Returns:
dilated_same_label: A tensor of shape [8, num_instances, height, width]
dilated_same_label: A tensor of shape [8, batch_size, num_instances,
height, width]
"""
instance_masks = tf.transpose(instance_masks, (1, 2, 0))
# instance_masks: [batch_size, height, width, num_instances]
instance_masks = tf.transpose(instance_masks, (0, 2, 3, 1))
# neighbors: [8, batch_size, height, width, num_instances]
neighbors = generate_2d_neighbors(instance_masks, dilation)
# instance_masks = [1, batch_size, height, width, num_instances]
instance_masks = instance_masks[tf.newaxis]
same_mask_prob = ((instance_masks * neighbors) +
((1 - instance_masks) * (1 - neighbors)))
return tf.transpose(same_mask_prob, (0, 3, 1, 2))
return tf.transpose(same_mask_prob, (0, 1, 4, 2, 3))
def _per_pixel_single_conv(input_tensor, params, channels):
......@@ -722,6 +772,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
return tf.squeeze(out, axis=-1)
def _batch_gt_list(gt_list):
return tf.stack(gt_list, axis=0)
def deepmac_proto_to_params(deepmac_config):
"""Convert proto to named tuple."""
......@@ -765,7 +819,8 @@ def deepmac_proto_to_params(deepmac_config):
color_consistency_warmup_steps=
deepmac_config.color_consistency_warmup_steps,
color_consistency_warmup_start=
deepmac_config.color_consistency_warmup_start
deepmac_config.color_consistency_warmup_start,
use_only_last_stage=deepmac_config.use_only_last_stage
)
......@@ -808,8 +863,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
loss = self._deepmac_params.classification_loss
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
......@@ -847,60 +900,62 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Get the input to the mask network, given bounding boxes.
Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in
normalized coordinates.
pixel_embedding: A [height, width, embedding_size] float tensor
containing spatial pixel embeddings.
boxes: A [batch_size, num_instances, 4] float tensor containing bounding
boxes in normalized coordinates.
pixel_embedding: A [batch_size, height, width, embedding_size] float
tensor containing spatial pixel embeddings.
Returns:
embedding: A [num_instances, mask_height, mask_width, embedding_size + 2]
float tensor containing the inputs to the mask network. For each
bounding box, we concatenate the normalized box coordinates to the
cropped pixel embeddings. If predict_full_resolution_masks is set,
mask_height and mask_width are the same as height and width of
pixel_embedding. If not, mask_height and mask_width are the same as
mask_size.
embedding: A [batch_size, num_instances, mask_height, mask_width,
embedding_size + 2] float tensor containing the inputs to the mask
network. For each bounding box, we concatenate the normalized box
coordinates to the cropped pixel embeddings. If
predict_full_resolution_masks is set, mask_height and mask_width are
the same as height and width of pixel_embedding. If not, mask_height
and mask_width are the same as mask_size.
"""
num_instances = tf.shape(boxes)[0]
batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks:
num_instances = tf.shape(boxes)[0]
pixel_embedding = pixel_embedding[tf.newaxis, :, :, :]
num_instances = tf.shape(boxes)[1]
pixel_embedding = pixel_embedding[:, tf.newaxis, :, :, :]
pixel_embeddings_processed = tf.tile(pixel_embedding,
[num_instances, 1, 1, 1])
[1, num_instances, 1, 1, 1])
image_shape = tf.shape(pixel_embeddings_processed)
image_height, image_width = image_shape[1], image_shape[2]
image_height, image_width = image_shape[2], image_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height),
tf.linspace(0.0, 1.0, image_width),
indexing='ij')
blist = box_list.BoxList(boxes)
ycenter, xcenter, _, _ = blist.get_center_coordinates_and_sizes()
y_grid = y_grid[tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, :, :]
ycenter = (boxes[:, :, 0] + boxes[:, :, 2]) / 2.0
xcenter = (boxes[:, :, 1] + boxes[:, :, 3]) / 2.0
y_grid = y_grid[tf.newaxis, tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, tf.newaxis, :, :]
y_grid -= ycenter[:, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=3)
y_grid -= ycenter[:, :, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, :, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=4)
else:
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
pixel_embeddings_processed = crop_and_resize_feature_map(
pixel_embedding, boxes, mask_size)
embeddings = spatial_transform_ops.matmul_crop_and_resize(
pixel_embedding, boxes, [mask_size, mask_size])
pixel_embeddings_processed = embeddings
mask_shape = tf.shape(pixel_embeddings_processed)
mask_height, mask_width = mask_shape[1], mask_shape[2]
mask_height, mask_width = mask_shape[2], mask_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height),
tf.linspace(-1.0, 1.0, mask_width),
indexing='ij')
coords = tf.stack([y_grid, x_grid], axis=2)
coords = coords[tf.newaxis, :, :, :]
coords = tf.tile(coords, [num_instances, 1, 1, 1])
coords = coords[tf.newaxis, tf.newaxis, :, :, :]
coords = tf.tile(coords, [batch_size, num_instances, 1, 1, 1])
if self._deepmac_params.use_xy:
return tf.concat([coords, pixel_embeddings_processed], axis=3)
return tf.concat([coords, pixel_embeddings_processed], axis=4)
else:
return pixel_embeddings_processed
......@@ -908,43 +963,94 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Return the instance embeddings from bounding box centers.
Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The
coordinates are in normalized input space.
instance_embedding: A [height, width, embedding_size] float tensor
containing the instance embeddings.
boxes: A [batch_size, num_instances, 4] float tensor holding bounding
boxes. The coordinates are in normalized input space.
instance_embedding: A [batch_size, height, width, embedding_size] float
tensor containing the instance embeddings.
Returns:
instance_embeddings: A [num_instances, embedding_size] shaped float tensor
containing the center embedding for each instance.
instance_embeddings: A [batch_size, num_instances, embedding_size]
shaped float tensor containing the center embedding for each instance.
"""
blist = box_list.BoxList(boxes)
output_height = tf.shape(instance_embedding)[0]
output_width = tf.shape(instance_embedding)[1]
blist_output = box_list_ops.to_absolute_coordinates(
blist, output_height, output_width, check_range=False)
(y_center_output, x_center_output,
_, _) = blist_output.get_center_coordinates_and_sizes()
center_coords_output = tf.stack([y_center_output, x_center_output], axis=1)
output_height = tf.cast(tf.shape(instance_embedding)[1], tf.float32)
output_width = tf.cast(tf.shape(instance_embedding)[2], tf.float32)
ymin = boxes[:, :, 0]
xmin = boxes[:, :, 1]
ymax = boxes[:, :, 2]
xmax = boxes[:, :, 3]
y_center_output = (ymin + ymax) * output_height / 2.0
x_center_output = (xmin + xmax) * output_width / 2.0
center_coords_output = tf.stack([y_center_output, x_center_output], axis=2)
center_coords_output_int = tf.cast(center_coords_output, tf.int32)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int,
batch_dims=1)
return center_latents
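# Note on the batched gather above (illustrative): with instance_embedding of
# shape [B, H, W, C] and center_coords_output_int of shape [B, N, 2] holding
# integer (y, x) indices, tf.gather_nd(..., batch_dims=1) returns a [B, N, C]
# tensor, i.e. the embedding at each instance's box center within its own
# batch element.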
def predict(self, preprocessed_inputs, other_inputs):
prediction_dict = super(DeepMACMetaArch, self).predict(
preprocessed_inputs, other_inputs)
mask_logits = self._predict_mask_logits_from_gt_boxes(prediction_dict)
prediction_dict[MASK_LOGITS_GT_BOXES] = mask_logits
return prediction_dict
def _predict_mask_logits_from_embeddings(
self, pixel_embedding, instance_embedding, boxes):
mask_input = self._get_mask_head_input(boxes, pixel_embedding)
mask_input, batch_size, num_instances = flatten_first2_dims(mask_input)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
instance_embeddings, _, _ = flatten_first2_dims(instance_embeddings)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_logits = unpack_first2_dims(
mask_logits, batch_size, num_instances)
return mask_logits
def _predict_mask_logits_from_gt_boxes(self, prediction_dict):
mask_logits_list = []
boxes = _batch_gt_list(self.groundtruth_lists(fields.BoxListFields.boxes))
instance_embedding_list = prediction_dict[INSTANCE_EMBEDDING]
pixel_embedding_list = prediction_dict[PIXEL_EMBEDDING]
if self._deepmac_params.use_only_last_stage:
instance_embedding_list = [instance_embedding_list[-1]]
pixel_embedding_list = [pixel_embedding_list[-1]]
for (instance_embedding, pixel_embedding) in zip(instance_embedding_list,
pixel_embedding_list):
mask_logits_list.append(
self._predict_mask_logits_from_embeddings(
pixel_embedding, instance_embedding, boxes))
return mask_logits_list
def _get_groundtruth_mask_output(self, boxes, masks):
"""Get the expected mask output for each box.
Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in
normalized coordinates.
masks: A [num_instances, height, width] float tensor containing binary
ground truth masks.
boxes: A [batch_size, num_instances, 4] float tensor containing bounding
boxes in normalized coordinates.
masks: A [batch_size, num_instances, height, width] float tensor
containing binary ground truth masks.
Returns:
masks: If predict_full_resolution_masks is set, masks are not resized
and the size of this tensor is [num_instances, input_height, input_width].
Otherwise, returns a tensor of size [num_instances, mask_size, mask_size].
and the size of this tensor is [batch_size, num_instances,
input_height, input_width]. Otherwise, returns a tensor of size
[batch_size, num_instances, mask_size, mask_size].
"""
mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks:
return masks
......@@ -957,9 +1063,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return cropped_masks
def _resize_logits_like_gt(self, logits, gt):
height, width = tf.shape(gt)[1], tf.shape(gt)[2]
height, width = tf.shape(gt)[2], tf.shape(gt)[3]
return resize_instance_masks(logits, (height, width))
def _aggregate_classification_loss(self, loss, gt, pred, method):
......@@ -1016,54 +1120,59 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else:
raise ValueError('Unknown loss aggregation - {}'.format(method))
def _compute_per_instance_mask_prediction_loss(
def _compute_mask_prediction_loss(
self, boxes, mask_logits, mask_gt):
"""Compute the per-instance mask loss.
Args:
boxes: A [num_instances, 4] float tensor of GT boxes.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks
mask_gt: The groundtruth mask.
boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
mask_logits: A [batch_size, num_instances, height, width] float tensor of
predicted masks
mask_gt: The groundtruth mask of same shape as mask_logits.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance.
"""
num_instances = tf.shape(boxes)[0]
batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
mask_logits = tf.reshape(mask_logits, [num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [num_instances, -1, 1])
mask_logits = tf.reshape(mask_logits, [batch_size * num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [batch_size * num_instances, -1, 1])
loss = self._deepmac_params.classification_loss(
prediction_tensor=mask_logits,
target_tensor=mask_gt,
weights=tf.ones_like(mask_logits))
return self._aggregate_classification_loss(
loss = self._aggregate_classification_loss(
loss, mask_gt, mask_logits, 'normalize_auto')
return tf.reshape(loss, [batch_size, num_instances])
def _compute_per_instance_box_consistency_loss(
def _compute_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits):
"""Compute the per-instance box consistency loss.
Args:
boxes_gt: A [num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [num_instances, 4] float tensor of augmented boxes,
to be used when using crop-and-resize based mask head.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks.
boxes_gt: A [batch_size, num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [batch_size, num_instances, 4] float tensor of
augmented boxes, to be used when using crop-and-resize based mask head.
mask_logits: A [batch_size, num_instances, height, width]
float tensor of predicted masks.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
loss: A [batch_size, num_instances] shaped tensor with the loss for
each instance in the batch.
"""
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2]
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, tf.newaxis]
shape = tf.shape(mask_logits)
batch_size, num_instances, height, width = (
shape[0], shape[1], shape[2], shape[3])
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, :, tf.newaxis]
if self._deepmac_params.predict_full_resolution_masks:
gt_crop = filled_boxes[:, :, :, 0]
pred_crop = mask_logits[:, :, :, 0]
gt_crop = filled_boxes[:, :, :, :, 0]
pred_crop = mask_logits[:, :, :, :, 0]
else:
gt_crop = crop_and_resize_instance_masks(
filled_boxes, boxes_for_crop, self._deepmac_params.mask_size)
......@@ -1071,7 +1180,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits, boxes_for_crop, self._deepmac_params.mask_size)
loss = 0.0
for axis in [1, 2]:
for axis in [2, 3]:
if self._deepmac_params.box_consistency_tightness:
pred_max_raw = tf.reduce_max(pred_crop, axis=axis)
......@@ -1083,44 +1192,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else:
pred_max = tf.reduce_max(pred_crop, axis=axis)
pred_max = pred_max[:, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis]
pred_max = pred_max[:, :, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, :, tf.newaxis]
flat_pred, batch_size, num_instances = flatten_first2_dims(pred_max)
flat_gt, _, _ = flatten_first2_dims(gt_max)
# We use flat tensors while calling loss functions because we
# want the loss per-instance to later multiply with the per-instance
# weight. Flattening the first 2 dims allows us to represent each instance
# in each batch as though they were samples in a larger batch.
raw_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max,
target_tensor=gt_max,
weights=tf.ones_like(pred_max))
prediction_tensor=flat_pred,
target_tensor=flat_gt,
weights=tf.ones_like(flat_pred))
loss += self._aggregate_classification_loss(
raw_loss, gt_max, pred_max,
agg_loss = self._aggregate_classification_loss(
raw_loss, flat_gt, flat_pred,
self._deepmac_params.box_consistency_loss_normalize)
loss += unpack_first2_dims(agg_loss, batch_size, num_instances)
return loss
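# Intuition for the loss above (illustrative): reduce_max over axis 2 (height)
# and axis 3 (width) collapses both the filled ground-truth boxes and the
# predicted mask logits into per-column and per-row profiles; the
# classification loss then encourages the predicted mask's extent along each
# axis to agree with the box extent, without requiring pixel-accurate masks.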
def _compute_per_instance_color_consistency_loss(
def _compute_color_consistency_loss(
self, boxes, preprocessed_image, mask_logits):
"""Compute the per-instance color consistency loss.
Args:
boxes: A [num_instances, 4] float tensor of GT boxes.
preprocessed_image: A [height, width, 3] float tensor containing the
preprocessed image.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks.
boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
preprocessed_image: A [batch_size, height, width, 3]
float tensor containing the preprocessed image.
mask_logits: A [batch_size, num_instances, height, width] float tensor of
predicted masks.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance for each sample in the batch.
"""
if not self._deepmac_params.predict_full_resolution_masks:
logging.info('Color consistency is not implemented with RoIAlign '
', i.e., fixed-size masks. Returning 0 loss.')
return tf.zeros(tf.shape(boxes)[0])
return tf.zeros(tf.shape(boxes)[:2])
dilation = self._deepmac_params.color_consistency_dilation
height, width = (tf.shape(preprocessed_image)[0],
tf.shape(preprocessed_image)[1])
height, width = (tf.shape(preprocessed_image)[1],
tf.shape(preprocessed_image)[2])
color_similarity = dilated_cross_pixel_similarity(
preprocessed_image, dilation=dilation, theta=2.0)
mask_probs = tf.nn.sigmoid(mask_logits)
......@@ -1132,20 +1250,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
color_similarity_mask = (
color_similarity > self._deepmac_params.color_consistency_threshold)
color_similarity_mask = tf.cast(
color_similarity_mask[:, tf.newaxis, :, :], tf.float32)
color_similarity_mask[:, :, tf.newaxis, :, :], tf.float32)
per_pixel_loss = -(color_similarity_mask *
tf.math.log(same_mask_label_probability))
# TODO(vighneshb) explore if shrinking the box by 1px helps.
box_mask = fill_boxes(boxes, height, width)
box_mask_expanded = box_mask[tf.newaxis, :, :, :]
box_mask_expanded = box_mask[tf.newaxis]
per_pixel_loss = per_pixel_loss * box_mask_expanded
loss = tf.reduce_sum(per_pixel_loss, axis=[0, 2, 3])
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2]))
loss = tf.reduce_sum(per_pixel_loss, axis=[0, 3, 4])
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[2, 3]))
loss = loss / num_box_pixels
if ((self._deepmac_params.color_consistency_warmup_steps > 0) and
self._is_training):
tf.keras.backend.learning_phase()):
training_step = tf.cast(self.training_step, tf.float32)
warmup_steps = tf.cast(
self._deepmac_params.color_consistency_warmup_steps, tf.float32)
......@@ -1157,56 +1275,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return loss
def _compute_per_instance_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding,
image):
def _compute_deepmac_losses(
self, boxes, masks_logits, masks_gt, image):
"""Returns the mask loss per instance.
Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The
coordinates are in normalized input space.
masks: A [num_instances, input_height, input_width] float tensor
containing the instance masks.
instance_embedding: A [output_height, output_width, embedding_size]
float tensor containing the instance embeddings.
pixel_embedding: optional [output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embeddings.
image: [output_height, output_width, channels] float tensor
boxes: A [batch_size, num_instances, 4] float tensor holding bounding
boxes. The coordinates are in normalized input space.
masks_logits: A [batch_size, num_instances, input_height, input_width]
float tensor containing the instance mask predictions in their logit
form.
masks_gt: A [batch_size, num_instances, input_height, input_width] float
tensor containing the groundtruth masks.
image: [batch_size, output_height, output_width, channels] float tensor
denoting the input image.
Returns:
mask_prediction_loss: A [num_instances] shaped float tensor containing the
mask loss for each instance.
box_consistency_loss: A [num_instances] shaped float tensor containing
the box consistency loss for each instance.
color_consistency_loss: A [num_instances] shaped float tensor containing
the color consistency loss.
mask_prediction_loss: A [batch_size, num_instances] shaped float tensor
containing the mask loss for each instance in the batch.
box_consistency_loss: A [batch_size, num_instances] shaped float tensor
containing the box consistency loss for each instance in the batch.
color_consistency_loss: A [batch_size, num_instances] shaped float tensor
containing the color consistency loss in the batch.
"""
if tf.keras.backend.learning_phase():
boxes_for_crop = preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
boxes = tf.stop_gradient(boxes)
def jitter_func(boxes):
return preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
boxes_for_crop = tf.map_fn(jitter_func,
boxes, parallel_iterations=128)
else:
boxes_for_crop = boxes
mask_input = self._get_mask_head_input(
boxes_for_crop, pixel_embedding)
instance_embeddings = self._get_instance_embeddings(
boxes_for_crop, instance_embedding)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_gt = self._get_groundtruth_mask_output(boxes_for_crop, masks)
mask_gt = self._get_groundtruth_mask_output(
boxes_for_crop, masks_gt)
mask_prediction_loss = self._compute_per_instance_mask_prediction_loss(
boxes_for_crop, mask_logits, mask_gt)
mask_prediction_loss = self._compute_mask_prediction_loss(
boxes_for_crop, masks_logits, mask_gt)
box_consistency_loss = self._compute_per_instance_box_consistency_loss(
boxes, boxes_for_crop, mask_logits)
box_consistency_loss = self._compute_box_consistency_loss(
boxes, boxes_for_crop, masks_logits)
color_consistency_loss = self._compute_per_instance_color_consistency_loss(
boxes, image, mask_logits)
color_consistency_loss = self._compute_color_consistency_loss(
boxes, image, masks_logits)
return {
DEEP_MASK_ESTIMATION: mask_prediction_loss,
......@@ -1224,7 +1339,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
' consistency loss is not supported in TF1.'))
return tfio.experimental.color.rgb_to_lab(raw_image)
def _compute_instance_masks_loss(self, prediction_dict):
def _compute_masks_loss(self, prediction_dict):
"""Computes the mask loss.
Args:
......@@ -1236,10 +1351,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
Returns:
loss_dict: A dict mapping string (loss names) to scalar floats.
"""
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids)
......@@ -1248,8 +1359,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for loss_name in MASK_LOSSES:
loss_dict[loss_name] = 0.0
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0])
height, width = prediction_shape[1], prediction_shape[2]
prediction_shape = tf.shape(prediction_dict[MASK_LOGITS_GT_BOXES][0])
height, width = prediction_shape[2], prediction_shape[3]
preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width))
......@@ -1258,42 +1369,46 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
# TODO(vighneshb) See if we can save memory by only using the final
# prediction
# Iterate over multiple predictions by backbone (for hourglass length=2)
for instance_pred, pixel_pred in zip(
prediction_dict[INSTANCE_EMBEDDING],
prediction_dict[PIXEL_EMBEDDING]):
# Iterate over samples in batch
# TODO(vighneshb) find out how autograph is handling this. Converting
# to a single op may give speed improvements
for i, (boxes, weights, classes, masks) in enumerate(
zip(gt_boxes_list, gt_weights_list, gt_classes_list, gt_masks_list)):
# TODO(vighneshb) Add sub-sampling back if required.
classes, valid_mask_weights, masks = filter_masked_classes(
allowed_masked_classes_ids, classes, weights, masks)
sample_loss_dict = self._compute_per_instance_deepmac_losses(
boxes, masks, instance_pred[i], pixel_pred[i], image[i])
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= weights
num_instances = tf.maximum(tf.reduce_sum(weights), 1.0)
num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) /
num_instances_allowed)
for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) /
num_instances)
batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING])
return dict((key, loss / float(batch_size * num_predictions))
gt_boxes = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.boxes))
gt_weights = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.weights))
gt_masks = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.masks))
gt_classes = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.classes))
mask_logits_list = prediction_dict[MASK_LOGITS_GT_BOXES]
for mask_logits in mask_logits_list:
# TODO(vighneshb) Add sub-sampling back if required.
_, valid_mask_weights, gt_masks = filter_masked_classes(
allowed_masked_classes_ids, gt_classes,
gt_weights, gt_masks)
sample_loss_dict = self._compute_deepmac_losses(
gt_boxes, mask_logits, gt_masks, image)
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= gt_weights
num_instances = tf.maximum(tf.reduce_sum(gt_weights), 1.0)
num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) /
num_instances_allowed)
for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) /
num_instances)
num_predictions = len(mask_logits_list)
return dict((key, loss / float(num_predictions))
for key, loss in loss_dict.items())
def loss(self, prediction_dict, true_image_shapes, scope=None):
......@@ -1302,7 +1417,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict, true_image_shapes, scope)
if self._deepmac_params is not None:
mask_loss_dict = self._compute_instance_masks_loss(
mask_loss_dict = self._compute_masks_loss(
prediction_dict=prediction_dict)
for loss_name in MASK_LOSSES:
......@@ -1363,50 +1478,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_size] containing binary per-box instance masks.
"""
def process(elems):
boxes, instance_embedding, pixel_embedding = elems
return self._postprocess_sample(boxes, instance_embedding,
pixel_embedding)
max_instances = self._center_params.max_box_predictions
return tf.map_fn(process, [boxes_output_stride, instance_embedding,
pixel_embedding],
dtype=tf.float32, parallel_iterations=max_instances)
def _postprocess_sample(self, boxes_output_stride,
instance_embedding, pixel_embedding):
"""Post process masks for a single sample.
Args:
boxes_output_stride: A [num_instances, 4] float tensor containing
bounding boxes in the absolute output space.
instance_embedding: A [output_height, output_width, embedding_size]
float tensor containing instance embeddings.
pixel_embedding: A [batch_size, output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embedding.
Returns:
masks: A float tensor of size [num_instances, mask_height, mask_width]
containing binary per-box instance masks. If
predict_full_resolution_masks is set, the masks will be resized to
postprocess_crop_size. Otherwise, mask_height=mask_width=mask_size
"""
height, width = (tf.shape(instance_embedding)[0],
tf.shape(instance_embedding)[1])
height, width = (tf.shape(instance_embedding)[1],
tf.shape(instance_embedding)[2])
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
blist = box_list.BoxList(boxes_output_stride)
blist = box_list_ops.to_normalized_coordinates(
blist, height, width, check_range=False)
boxes = blist.get()
mask_input = self._get_mask_head_input(boxes, pixel_embedding)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
ymin, xmin, ymax, xmax = tf.unstack(boxes_output_stride, axis=2)
ymin /= height
ymax /= height
xmin /= width
xmax /= width
boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_logits = self._predict_mask_logits_from_embeddings(
pixel_embedding, instance_embedding, boxes)
# TODO(vighneshb) Explore sweeping mask thresholds.
......@@ -1416,7 +1499,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
height *= self._stride
width *= self._stride
mask_logits = resize_instance_masks(mask_logits, (height, width))
mask_logits = crop_masks_within_boxes(
mask_logits = crop_and_resize_instance_masks(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
masks_prob = tf.nn.sigmoid(mask_logits)
......