Unverified Commit c127d527 authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents 78657911 457bcb85
@@ -64,6 +64,72 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_mobile_conv2d_bn(self):
batch_norm_op = tf.keras.layers.BatchNormalization(
momentum=0.9,
epsilon=1.,
name='bn')
conv2d = movinet_layers.MobileConv2D(
filters=3,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
kernel_initializer='ones',
use_bias=False,
use_depthwise=False,
use_temporal=False,
use_buffered_input=True,
batch_norm_op=batch_norm_op,
)
inputs = tf.ones([1, 2, 2, 2, 3])
predicted = conv2d(inputs)
expected = tf.constant(
[[[[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]],
[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]]],
[[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]],
[[8.48528, 8.48528, 8.48528],
[8.48528, 8.48528, 8.48528]]]]])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_mobile_conv2d_activation(self):
conv2d = movinet_layers.MobileConv2D(
filters=3,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
kernel_initializer='ones',
use_bias=False,
use_depthwise=False,
use_temporal=False,
use_buffered_input=True,
activation_op=tf.nn.relu6,
)
inputs = tf.ones([1, 2, 2, 2, 3])
predicted = conv2d(inputs)
expected = tf.constant(
[[[[[6., 6., 6.],
[6., 6., 6.]],
[[6., 6., 6.],
[6., 6., 6.]]],
[[[6., 6., 6.],
[6., 6., 6.]],
[[6., 6., 6.],
[6., 6., 6.]]]]])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
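The expected constants in the two tests above follow from a short calculation (a sketch, assuming the BatchNormalization layer runs with its default inference statistics of zero mean and unit variance):

```python
import math

# Each output pixel of the 3x3 'same' convolution over a 2x2 all-ones input
# with 3 channels and an all-ones kernel sums 4 spatial positions x 3 channels.
conv_output = 4 * 3                        # = 12.0

# test_mobile_conv2d_bn: BN at inference with moving_mean=0, moving_var=1,
# gamma=1, beta=0 and epsilon=1.0 divides by sqrt(1 + 1).
print(conv_output / math.sqrt(1.0 + 1.0))  # ~8.48528, matching the test

# test_mobile_conv2d_activation: relu6 simply caps the same 12.0 at 6.0.
print(min(conv_output, 6))                 # 6
```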
def test_mobile_conv2d_temporal(self):
conv2d = movinet_layers.MobileConv2D(
filters=3,
@@ -378,6 +444,35 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_stream_movinet_block_none_se(self):
block = movinet_layers.MovinetBlock(
out_filters=3,
expand_filters=6,
kernel_size=(3, 3, 3),
strides=(1, 2, 2),
causal=True,
se_type='none',
state_prefix='test',
)
inputs = tf.range(4, dtype=tf.float32) + 1.
inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
expected, expected_states = block(inputs)
for num_splits in [1, 2, 4]:
frames = tf.split(inputs, inputs.shape[1] // num_splits, axis=1)
states = {}
predicted = []
for frame in frames:
x, states = block(frame, states=states)
predicted.append(x)
predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllEqual(list(expected_states.keys()), ['test_stream_buffer'])
def test_stream_classifier_head(self):
head = movinet_layers.Head(project_filters=5)
classifier_head = movinet_layers.ClassifierHead(
...
@@ -99,6 +99,49 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
def test_movinet_stream_nse(self):
"""Test if the backbone can be run in streaming mode w/o SE layer."""
tf.keras.backend.set_image_data_format('channels_last')
backbone = movinet.Movinet(
model_id='a0',
causal=True,
use_external_states=True,
se_type='none',
)
inputs = tf.ones([1, 5, 128, 128, 3])
init_states = backbone.init_states(tf.shape(inputs))
expected_endpoints, _ = backbone({**init_states, 'image': inputs})
frames = tf.split(inputs, inputs.shape[1], axis=1)
states = init_states
for frame in frames:
output, states = backbone({**states, 'image': frame})
predicted_endpoints = output
predicted = predicted_endpoints['head']
# The expected final output is simply the mean across frames
expected = expected_endpoints['head']
expected = tf.reduce_mean(expected, 1, keepdims=True)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected, 1e-5, 1e-5)
# Check contents in the states dictionary.
state_keys = list(init_states.keys())
self.assertIn('state_head_pool_buffer', state_keys)
self.assertIn('state_head_pool_frame_count', state_keys)
state_keys.remove('state_head_pool_buffer')
state_keys.remove('state_head_pool_frame_count')
# From now on, there are only 'stream_buffer' for the convolutions.
for state_key in state_keys:
self.assertIn(
'stream_buffer', state_key,
msg=f'Expecting stream_buffer only, found {state_key}')
def test_movinet_2plus1d_stream(self):
tf.keras.backend.set_image_data_format('channels_last')
...
@@ -82,6 +82,9 @@ flags.DEFINE_string(
flags.DEFINE_string(
'activation', 'swish',
'The main activation to use across layers.')
flags.DEFINE_string(
'classifier_activation', 'swish',
'The classifier activation to use.')
flags.DEFINE_string(
'gating_activation', 'sigmoid',
'The gating activation to use in squeeze-excitation layers.')
@@ -124,11 +127,15 @@ def main(_) -> None:
# states. These dimensions can be set to `None` once the model is built.
input_shape = [1 if s is None else s for s in input_specs.shape]
- # Override swish activation implementation to remove custom gradients
activation = FLAGS.activation
if activation == 'swish':
+ # Override swish activation implementation to remove custom gradients
activation = 'simple_swish'
classifier_activation = FLAGS.classifier_activation
if classifier_activation == 'swish':
classifier_activation = 'simple_swish'
backbone = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
@@ -145,9 +152,7 @@ def main(_) -> None:
num_classes=FLAGS.num_classes,
output_states=FLAGS.causal,
input_specs=dict(image=input_specs),
- # TODO(dankondratyuk): currently set to swish, but will need to
- # re-train to use other activations.
- activation='simple_swish')
+ activation=classifier_activation)
model.build(input_shape)
# Compile model to generate some internal Keras variables.
...
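For reference, the override above applies the same 'swish' to 'simple_swish' substitution independently to the layer activation and to the new classifier activation flag; a minimal sketch of that mapping, with illustrative input values:

```python
def resolve_export_activation(name: str) -> str:
  # 'simple_swish' avoids the custom-gradient swish implementation, which is
  # why the exporter overrides it; other names pass through unchanged.
  return 'simple_swish' if name == 'swish' else name

assert resolve_export_activation('swish') == 'simple_swish'
assert resolve_export_activation('hard_swish') == 'hard_swish'
```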
@@ -18,7 +18,7 @@ from absl import flags
import tensorflow as tf
import tensorflow_hub as hub
- from official.projects.movinet import export_saved_model
+ from official.projects.movinet.tools import export_saved_model
FLAGS = flags.FLAGS
...
@@ -145,7 +145,7 @@ class Encoder(tf.keras.layers.Layer):
self._encoder_layers = []
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
- # https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
+ # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
for i in range(self._num_layers):
encoder_layer = nn_blocks.TransformerEncoderBlock(
inner_activation=activations.gelu,
...
@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision package definition."""
# Lint as: python3
# pylint: disable=unused-import
from official.vision.beta import configs
from official.vision.beta import tasks
@@ -55,6 +55,20 @@ depth, label smoothing and dropout.
| ResNet-RS-350 | 256x256 | 164.3 | 83.7 | 96.7 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs350_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-350-i256.tar.gz) |
| ResNet-RS-350 | 320x320 | 164.3 | 84.2 | 96.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/image_classification/imagenet_resnetrs420_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-350-i320.tar.gz) |
#### Vision Transformer (ViT)
We support [ViT](https://arxiv.org/abs/2010.11929) and [DEIT](https://arxiv.org/abs/2012.12877) implementations in a TF
Vision
[project](https://github.com/tensorflow/models/tree/master/official/projects/vit). ViT models trained under the DEIT settings:
model | resolution | Top-1 | Top-5 |
--------- | :--------: | ----: | ----: |
ViT-s16 | 224x224 | 79.4 | 94.7 |
ViT-b16 | 224x224 | 81.8 | 95.8 |
ViT-l16 | 224x224 | 82.2 | 95.8 |
## Object Detection and Instance Segmentation
### Common Settings and Notes
@@ -123,6 +137,7 @@ evaluated on [COCO](https://cocodataset.org/) val2017.
| Backbone | Resolution | Epochs | Params (M) | Box AP | Mask AP | Download
------------ | :--------: | -----: | ---------: | -----: | ------: | -------:
| SpineNet-49 | 640x640 | 500 | 56.4 | 46.4 | 40.0 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet49_cascadercnn_tpu.yaml)|
| SpineNet-96 | 1024x1024 | 500 | 70.8 | 50.9 | 43.8 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet96_cascadercnn_tpu.yaml)|
| SpineNet-143 | 1280x1280 | 500 | 94.9 | 51.9 | 45.0 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/maskrcnn/coco_spinenet143_cascadercnn_tpu.yaml)|
## Semantic Segmentation
...
# MobileNetV3Small ImageNet classification. 67.5% top-1 and 87.6% top-5 accuracy.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
num_classes: 1001
input_size: [224, 224, 3]
backbone:
type: 'mobilenet'
mobilenet:
model_id: 'MobileNetV3Small'
filter_size_scale: 1.0
norm_activation:
activation: 'relu'
norm_momentum: 0.997
norm_epsilon: 0.001
use_sync_bn: false
dropout_rate: 0.2
losses:
l2_weight_decay: 0.00001
one_hot: true
label_smoothing: 0.1
train_data:
input_path: 'imagenet-2012-tfrecord/train*'
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
global_batch_size: 4096
dtype: 'bfloat16'
drop_remainder: false
trainer:
train_steps: 312000 # 1000 epochs
validation_steps: 12
validation_interval: 312
steps_per_loop: 312 # NUM_EXAMPLES (1281167) // global_batch_size
summary_interval: 312
checkpoint_interval: 312
optimizer_config:
optimizer:
type: 'rmsprop'
rmsprop:
rho: 0.9
momentum: 0.9
epsilon: 0.002
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.01
decay_steps: 936 # 3 * steps_per_epoch
decay_rate: 0.99
staircase: true
ema:
average_decay: 0.9999
trainable_weights_only: false
warmup:
type: 'linear'
linear:
warmup_steps: 1560
warmup_learning_rate: 0.001
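The step counts in this config are derived from the ImageNet train-set size; a quick check of the arithmetic behind the comments above:

```python
num_examples = 1281167          # ImageNet-1k train images
global_batch_size = 4096

steps_per_epoch = num_examples // global_batch_size
print(steps_per_epoch)          # 312: steps_per_loop, summary/checkpoint/validation_interval
print(steps_per_epoch * 1000)   # 312000 train_steps ~= 1000 epochs
print(3 * steps_per_epoch)      # 936 decay_steps, i.e. decay every 3 epochs
print(1560 / steps_per_epoch)   # 5.0 -> warmup lasts ~5 epochs
print(50000 // 4096)            # 12 validation_steps over the 50k val images
```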
# --experiment_type=cascadercnn_spinenet_coco
# Expect to reach: box mAP: 51.9%, mask mAP: 45.0% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
@@ -8,12 +10,12 @@ task:
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
- aug_scale_max: 2.0
+ aug_scale_max: 2.5
losses:
l2_weight_decay: 0.00004
model:
anchor:
- anchor_size: 3.0
+ anchor_size: 4.0
num_scales: 3
min_level: 3
max_level: 7
...
@@ -714,7 +714,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
'use_depthwise': self._use_depthwise,
'use_residual': self._use_residual,
'norm_momentum': self._norm_momentum,
- 'norm_epsilon': self._norm_epsilon
+ 'norm_epsilon': self._norm_epsilon,
'output_intermediate_endpoints': self._output_intermediate_endpoints
}
base_config = super(InvertedBottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
...
@@ -2284,8 +2284,9 @@ class MixupAndCutmix:
lambda x: _fill_rectangle(*x),
(images, random_center_width, random_center_height, cut_width // 2,
cut_height // 2, tf.reverse(images, [0])),
- dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.float32),
- fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32))
+ dtype=(
+     images.dtype, tf.int32, tf.int32, tf.int32, tf.int32, images.dtype),
+ fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=images.dtype))
return images, labels, lam
@@ -2294,7 +2295,8 @@ class MixupAndCutmix:
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
labels.shape)
lam = tf.reshape(lam, [-1, 1, 1, 1])
- images = lam * images + (1. - lam) * tf.reverse(images, [0])
+ lam_cast = tf.cast(lam, dtype=images.dtype)
+ images = lam_cast * images + (1. - lam_cast) * tf.reverse(images, [0])
return images, labels, tf.squeeze(lam)
...
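The cast added above is what lets mixup run on half-precision inputs: the Beta-sampled lam is float32, so it has to be cast to the image dtype before blending. A self-contained sketch of that blend, using a fixed lam instead of sampling from a Beta distribution:

```python
import tensorflow as tf

images = tf.random.normal((4, 8, 8, 3), dtype=tf.float32)
images = tf.cast(images, tf.bfloat16)        # e.g. bfloat16 inputs on TPU

lam = tf.constant([0.7, 0.3, 0.5, 0.9])      # normally sampled from Beta(alpha, alpha)
lam = tf.reshape(lam, [-1, 1, 1, 1])
lam = tf.cast(lam, images.dtype)             # the cast introduced in this change

# Without the cast, multiplying a float32 lam by bfloat16 images raises a
# dtype mismatch error.
mixed = lam * images + (1. - lam) * tf.reverse(images, [0])
print(mixed.dtype)                           # bfloat16
```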
@@ -366,14 +366,19 @@ class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
self.assertNotEqual(0, tf.reduce_max(aug_image))
- class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
+ @parameterized.named_parameters([
+     ('float16_images', tf.float16),
+     ('bfloat16_images', tf.bfloat16),
+     ('float32_images', tf.float32),
+ ])
+ class MixupAndCutmixTest(parameterized.TestCase, tf.test.TestCase):
- def test_mixup_and_cutmix_smoothes_labels(self):
+ def test_mixup_and_cutmix_smoothes_labels(self, image_dtype):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
- images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
+ images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
@@ -388,12 +393,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4)  # With tolerance
- def test_mixup_changes_image(self):
+ def test_mixup_changes_image(self, image_dtype):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
- images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
+ images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
@@ -409,12 +414,12 @@ class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
1e4)  # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
- def test_cutmix_changes_image(self):
+ def test_cutmix_changes_image(self, image_dtype):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
- images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
+ images = tf.random.normal((batch_size, 224, 224, 3), dtype=image_dtype)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
...
@@ -25,6 +25,7 @@ from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation
from official.vision.beta.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as deepmac_maskrcnn
SEGMENTATION_MODEL = semantic_segmentation.SemanticSegmentationModel
@@ -89,7 +90,7 @@ class PanopticSegmentationGenerator(hyperparams.Config):
@dataclasses.dataclass
- class PanopticMaskRCNN(maskrcnn.MaskRCNN):
+ class PanopticMaskRCNN(deepmac_maskrcnn.DeepMaskHeadRCNN):
"""Panoptic Mask R-CNN model config."""
segmentation_model: semantic_segmentation.SemanticSegmentationModel = (
SEGMENTATION_MODEL(num_classes=2))
...
@@ -17,10 +17,10 @@
import tensorflow as tf
from official.vision.beta.modeling import backbones
- from official.vision.beta.modeling import factory as models_factory
from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.vision.beta.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
@@ -50,7 +50,7 @@ def build_panoptic_maskrcnn(
segmentation_config = model_config.segmentation_model
# Builds the maskrcnn model.
- maskrcnn_model = models_factory.build_maskrcnn(
+ maskrcnn_model = deep_mask_head_rcnn.build_maskrcnn(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
@@ -120,6 +120,7 @@ def build_panoptic_maskrcnn(
# Combines maskrcnn, and segmentation models to build panoptic segmentation
# model.
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone=maskrcnn_model.backbone,
decoder=maskrcnn_model.decoder,
...
@@ -18,10 +18,10 @@ from typing import List, Mapping, Optional, Union
import tensorflow as tf
- from official.vision.beta.modeling import maskrcnn_model
+ from official.vision.beta.projects.deepmac_maskrcnn.modeling import maskrcnn_model
- class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
+ class PanopticMaskRCNNModel(maskrcnn_model.DeepMaskRCNNModel):
"""The Panoptic Segmentation model."""
def __init__(
@@ -49,7 +49,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
- anchor_size: Optional[float] = None,  # pytype: disable=annotation-type-mismatch # typed-keras
+ anchor_size: Optional[float] = None,
+ use_gt_boxes_for_masks: bool = False,  # pytype: disable=annotation-type-mismatch # typed-keras
**kwargs):
"""Initializes the Panoptic Mask R-CNN model.
@@ -94,6 +95,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each scale level.
anchor_size: A number representing the scale of size of the base anchor to
the feature stride 2^level.
use_gt_boxes_for_masks: `bool`, whether to use only gt boxes for masks.
**kwargs: keyword arguments to be passed.
"""
super(PanopticMaskRCNNModel, self).__init__(
@@ -115,6 +117,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
use_gt_boxes_for_masks=use_gt_boxes_for_masks,
**kwargs)
self._config_dict.update({
...
@@ -97,6 +97,20 @@ class PanopticSegmentationModule(detection.DetectionModule):
anchor_boxes=anchor_boxes,
training=False)
detections.pop('rpn_boxes')
detections.pop('rpn_scores')
detections.pop('cls_outputs')
detections.pop('box_outputs')
detections.pop('backbone_features')
detections.pop('decoder_features')
# Normalize detection boxes to [0, 1]. Here we first map them to the
# original image size, then normalize them to [0, 1].
detections['detection_boxes'] = (
detections['detection_boxes'] /
tf.tile(image_info[:, 2:3, :], [1, 1, 2]) /
tf.tile(image_info[:, 0:1, :], [1, 1, 2]))
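A small worked example of the normalization above (a sketch; it assumes the usual TF Vision `image_info` layout of [[original size], [desired size], [y_scale, x_scale], [y_offset, x_offset]], and the numbers are purely illustrative):

```python
import tensorflow as tf

# Hypothetical values: a 100x200 original image resized to 640x640,
# so the scale row is (6.4, 3.2).
image_info = tf.constant([[[100., 200.],
                           [640., 640.],
                           [6.4, 3.2],
                           [0., 0.]]])                     # [batch=1, 4, 2]
boxes = tf.constant([[[64., 64., 320., 320.]]])            # in the 640x640 space

# Undo the resize scale, then divide by the original size to land in [0, 1].
boxes = boxes / tf.tile(image_info[:, 2:3, :], [1, 1, 2])  # -> [[10., 20., 50., 100.]]
boxes = boxes / tf.tile(image_info[:, 0:1, :], [1, 1, 2])  # -> [[0.1, 0.1, 0.5, 0.5]]
print(boxes)
```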
if model_params.detection_generator.apply_nms:
final_outputs = {
'detection_boxes': detections['detection_boxes'],
@@ -109,10 +123,15 @@ class PanopticSegmentationModule(detection.DetectionModule):
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores']
}
masks = detections['segmentation_outputs']
masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
classes = tf.math.argmax(masks, axis=-1)
scores = tf.nn.softmax(masks, axis=-1)
final_outputs.update({
'detection_masks': detections['detection_masks'],
- 'segmentation_outputs': detections['segmentation_outputs'],
+ 'masks': masks,
+ 'scores': scores,
+ 'classes': classes,
'image_info': image_info
})
if model_params.generate_panoptic_masks:
...
@@ -61,7 +61,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
def initialize(self, model: tf.keras.Model) -> None:
"""Loading pretrained checkpoint."""
- if not self.task_config.init_checkpoint_modules:
+ if not self.task_config.init_checkpoint:
return
def _get_checkpoint_path(checkpoint_dir_or_file):
...
@@ -34,7 +34,7 @@ import PIL.ImageFont as ImageFont
import six
import tensorflow as tf
- from official.vision.beta.ops import box_ops
+ from official.vision.ops import box_ops
from official.vision.utils.object_detection import shape_utils
_TITLE_LEFT_MARGIN = 10
...
@@ -1050,6 +1050,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
else:
raise ValueError(f'Unknown heatmap type - {self._box_heatmap_type}')
heatmap = tf.stop_gradient(heatmap)
heatmaps.append(heatmap)
# Return the stacked heatmaps over the batch.
...
@@ -30,6 +30,7 @@ if tf_version.is_tf2():
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
@@ -50,7 +51,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight',
'box_consistency_loss_normalize', 'box_consistency_tightness',
- 'color_consistency_warmup_steps', 'color_consistency_warmup_start'
+ 'color_consistency_warmup_steps', 'color_consistency_warmup_start',
+ 'use_only_last_stage'
])
@@ -140,33 +142,24 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Unknown network type {}'.format(name))
- def crop_masks_within_boxes(masks, boxes, output_size):
-   """Crops masks to lie tightly within the boxes.
-
-   Args:
-     masks: A [num_instances, height, width] float tensor of masks.
-     boxes: A [num_instances, 4] sized tensor of normalized bounding boxes.
-     output_size: The height and width of the output masks.
-
-   Returns:
-     masks: A [num_instances, output_size, output_size] tensor of masks which
-       are cropped to be tightly within the gives boxes and resized.
-   """
-   masks = spatial_transform_ops.matmul_crop_and_resize(
-       masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
-       [output_size, output_size])
-   return masks[:, 0, :, :, 0]
+ def _resize_instance_masks_non_empty(masks, shape):
+   """Resize a non-empty tensor of masks to the given shape."""
+   height, width = shape
+   flattened_masks, batch_size, num_instances = flatten_first2_dims(masks)
+   flattened_masks = flattened_masks[:, :, :, tf.newaxis]
+   flattened_masks = tf.image.resize(
+       flattened_masks, (height, width),
+       method=tf.image.ResizeMethod.BILINEAR)
+   return unpack_first2_dims(
+       flattened_masks[:, :, :, 0], batch_size, num_instances)

def resize_instance_masks(masks, shape):
-   height, width = shape
-   masks_ex = masks[:, :, :, tf.newaxis]
-   masks_ex = tf.image.resize(masks_ex, (height, width),
-                              method=tf.image.ResizeMethod.BILINEAR)
-   masks = masks_ex[:, :, :, 0]
-   return masks
+   batch_size, num_instances = tf.shape(masks)[0], tf.shape(masks)[1]
+   return tf.cond(
+       tf.shape(masks)[1] == 0,
+       lambda: tf.zeros((batch_size, num_instances, shape[0], shape[1])),
+       lambda: _resize_instance_masks_non_empty(masks, shape))
def filter_masked_classes(masked_class_ids, classes, weights, masks):
@@ -175,94 +168,132 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
Args:
masked_class_ids: A list of class IDs allowed to have masks. These class IDs
are 1-indexed.
classes: A [num_instances, num_classes] float tensor containing the one-hot classes: A [batch_size, num_instances, num_classes] float tensor containing
encoded classes. the one-hot encoded classes.
weights: A [num_instances] float tensor containing the weights of each weights: A [batch_size, num_instances] float tensor containing the weights
sample. of each sample.
masks: A [num_instances, height, width] tensor containing the mask per masks: A [batch_size, num_instances, height, width] tensor containing the
instance. mask per instance.
Returns: Returns:
classes_filtered: A [num_instances, num_classes] float tensor containing the classes_filtered: A [batch_size, num_instances, num_classes] float tensor
one-hot encoded classes with classes not in masked_class_ids zeroed out. containing the one-hot encoded classes with classes not in
weights_filtered: A [num_instances] float tensor containing the weights of masked_class_ids zeroed out.
each sample with instances whose classes aren't in masked_class_ids weights_filtered: A [batch_size, num_instances] float tensor containing the
zeroed out. weights of each sample with instances whose classes aren't in
masks_filtered: A [num_instances, height, width] tensor containing the mask masked_class_ids zeroed out.
per instance with masks not belonging to masked_class_ids zeroed out. masks_filtered: A [batch_size, num_instances, height, width] tensor
containing the mask per instance with masks not belonging to
masked_class_ids zeroed out.
""" """
if len(masked_class_ids) == 0: # pylint:disable=g-explicit-length-test if len(masked_class_ids) == 0: # pylint:disable=g-explicit-length-test
return classes, weights, masks return classes, weights, masks
if tf.shape(classes)[0] == 0: if tf.shape(classes)[1] == 0:
return classes, weights, masks return classes, weights, masks
masked_class_ids = tf.constant(np.array(masked_class_ids, dtype=np.int32)) masked_class_ids = tf.constant(np.array(masked_class_ids, dtype=np.int32))
label_id_offset = 1 label_id_offset = 1
masked_class_ids -= label_id_offset masked_class_ids -= label_id_offset
class_ids = tf.argmax(classes, axis=1, output_type=tf.int32) class_ids = tf.argmax(classes, axis=2, output_type=tf.int32)
matched_classes = tf.equal( matched_classes = tf.equal(
class_ids[:, tf.newaxis], masked_class_ids[tf.newaxis, :] class_ids[:, :, tf.newaxis], masked_class_ids[tf.newaxis, tf.newaxis, :]
) )
matched_classes = tf.reduce_any(matched_classes, axis=1) matched_classes = tf.reduce_any(matched_classes, axis=2)
matched_classes = tf.cast(matched_classes, tf.float32) matched_classes = tf.cast(matched_classes, tf.float32)
return ( return (
classes * matched_classes[:, tf.newaxis], classes * matched_classes[:, :, tf.newaxis],
weights * matched_classes, weights * matched_classes,
masks * matched_classes[:, tf.newaxis, tf.newaxis] masks * matched_classes[:, :, tf.newaxis, tf.newaxis]
) )
def crop_and_resize_feature_map(features, boxes, size): def flatten_first2_dims(tensor):
"""Crop and resize regions from a single feature map given a set of boxes. """Flatten first 2 dimensions of a tensor.
Args: Args:
features: A [H, W, C] float tensor. tensor: A tensor with shape [M, N, ....]
boxes: A [N, 4] tensor of norrmalized boxes.
size: int, the size of the output features.
Returns: Returns:
per_box_features: A [N, size, size, C] tensor of cropped and resized flattened_tensor: A tensor of shape [M * N, ...]
features. M: int, the length of the first dimension of the input.
N: int, the length of the second dimension of the input.
""" """
return spatial_transform_ops.matmul_crop_and_resize( shape = tf.shape(tensor)
features[tf.newaxis], boxes[tf.newaxis], [size, size])[0] d1, d2, rest = shape[0], shape[1], shape[2:]
tensor = tf.reshape(
tensor, tf.concat([[d1 * d2], rest], axis=0))
return tensor, d1, d2
def unpack_first2_dims(tensor, dim1, dim2):
"""Unpack the flattened first dimension of the tensor into 2 dimensions.
Args:
tensor: A tensor of shape [dim1 * dim2, ...]
dim1: int, the size of the first dimension.
dim2: int, the size of the second dimension.
Returns:
unflattened_tensor: A tensor of shape [dim1, dim2, ...].
"""
shape = tf.shape(tensor)
result_shape = tf.concat([[dim1, dim2], shape[1:]], axis=0)
return tf.reshape(tensor, result_shape)
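A quick shape round-trip showing how these two helpers are meant to compose (a sketch with arbitrary sizes, assuming flatten_first2_dims and unpack_first2_dims from the DeepMAC meta-architecture module are in scope):

```python
import tensorflow as tf

masks = tf.zeros([2, 5, 32, 32])   # [batch_size, num_instances, height, width]
flat, batch_size, num_instances = flatten_first2_dims(masks)
print(flat.shape)                  # (10, 32, 32): first two dims merged

restored = unpack_first2_dims(flat, batch_size, num_instances)
print(restored.shape)              # (2, 5, 32, 32): original layout recovered
```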
def crop_and_resize_instance_masks(masks, boxes, mask_size): def crop_and_resize_instance_masks(masks, boxes, mask_size):
"""Crop and resize each mask according to the given boxes. """Crop and resize each mask according to the given boxes.
Args: Args:
masks: A [N, H, W] float tensor. masks: A [B, N, H, W] float tensor.
boxes: A [N, 4] float tensor of normalized boxes. boxes: A [B, N, 4] float tensor of normalized boxes.
mask_size: int, the size of the output masks. mask_size: int, the size of the output masks.
Returns: Returns:
masks: A [N, mask_size, mask_size] float tensor of cropped and resized masks: A [B, N, mask_size, mask_size] float tensor of cropped and resized
instance masks. instance masks.
""" """
masks, batch_size, num_instances = flatten_first2_dims(masks)
boxes, _, _ = flatten_first2_dims(boxes)
cropped_masks = spatial_transform_ops.matmul_crop_and_resize( cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :], masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[mask_size, mask_size]) [mask_size, mask_size])
cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4]) cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
return unpack_first2_dims(cropped_masks, batch_size, num_instances)
return cropped_masks
def fill_boxes(boxes, height, width): def fill_boxes(boxes, height, width):
"""Fills the area included in the box.""" """Fills the area included in the boxes with 1s.
blist = box_list.BoxList(boxes)
blist = box_list_ops.to_absolute_coordinates(blist, height, width) Args:
boxes = blist.get() boxes: A [batch_size, num_instances, 4] shaped float tensor of boxes given
in the normalized coordinate space.
height: int, height of the output image.
width: int, width of the output image.
Returns:
filled_boxes: A [batch_size, num_instances, height, width] shaped float
tensor with 1s in the area that falls inside each box.
"""
ymin, xmin, ymax, xmax = tf.unstack( ymin, xmin, ymax, xmax = tf.unstack(
boxes[:, tf.newaxis, tf.newaxis, :], 4, axis=3) boxes[:, :, tf.newaxis, tf.newaxis, :], 4, axis=4)
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
ymin *= height
ymax *= height
xmin *= width
xmax *= width
ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij') ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32) ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32)
ygrid, xgrid = ygrid[tf.newaxis, :, :], xgrid[tf.newaxis, :, :] ygrid, xgrid = (ygrid[tf.newaxis, tf.newaxis, :, :],
xgrid[tf.newaxis, tf.newaxis, :, :])
filled_boxes = tf.logical_and( filled_boxes = tf.logical_and(
tf.logical_and(ygrid >= ymin, ygrid <= ymax), tf.logical_and(ygrid >= ymin, ygrid <= ymax),
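To make the fill_boxes contract concrete, here is a tiny standalone re-implementation for one normalized box on a 4x4 grid (an illustrative sketch, not the library code itself):

```python
import tensorflow as tf

boxes = tf.constant([[[0., 0., 0.5, 0.5]]])   # [batch=1, num_instances=1, 4], normalized
height, width = 4, 4

ymin, xmin, ymax, xmax = tf.unstack(boxes[:, :, tf.newaxis, tf.newaxis, :], 4, axis=4)
# Scale the normalized box into absolute pixel coordinates, as fill_boxes does.
ymin, ymax = ymin * height, ymax * height
xmin, xmax = xmin * width, xmax * width

ygrid, xgrid = tf.meshgrid(tf.range(height, dtype=tf.float32),
                           tf.range(width, dtype=tf.float32), indexing='ij')
ygrid = ygrid[tf.newaxis, tf.newaxis]
xgrid = xgrid[tf.newaxis, tf.newaxis]

filled = tf.cast(
    tf.logical_and(tf.logical_and(ygrid >= ymin, ygrid <= ymax),
                   tf.logical_and(xgrid >= xmin, xgrid <= xmax)), tf.float32)
print(filled[0, 0])   # 1s in the 3x3 top-left corner (rows/cols 0..2), 0s elsewhere
```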
@@ -289,7 +320,7 @@ def embedding_projection(x, y):
return dot
- def _get_2d_neighbors_kenel():
+ def _get_2d_neighbors_kernel():
"""Returns a conv. kernel that when applied generates 2D neighbors.
Returns:
@@ -311,20 +342,34 @@ def generate_2d_neighbors(input_tensor, dilation=2):
following ops on TPU won't have to pad the last dimension to 128.
Args:
input_tensor: A float tensor of shape [height, width, channels]. input_tensor: A float tensor of shape [batch_size, height, width, channels].
dilation: int, the dilation factor for considering neighbors. dilation: int, the dilation factor for considering neighbors.
Returns: Returns:
output: A float tensor of all 8 2-D neighbors. of shape output: A float tensor of all 8 2-D neighbors. of shape
[8, height, width, channels]. [8, batch_size, height, width, channels].
""" """
input_tensor = tf.transpose(input_tensor, (2, 0, 1))
input_tensor = input_tensor[:, :, :, tf.newaxis]
kernel = _get_2d_neighbors_kenel() # TODO(vighneshb) Minimize transposing here to save memory.
# input_tensor: [B, C, H, W]
input_tensor = tf.transpose(input_tensor, (0, 3, 1, 2))
# input_tensor: [B, C, H, W, 1]
input_tensor = input_tensor[:, :, :, :, tf.newaxis]
# input_tensor: [B * C, H, W, 1]
input_tensor, batch_size, channels = flatten_first2_dims(input_tensor)
kernel = _get_2d_neighbors_kernel()
# output: [B * C, H, W, 8]
output = tf.nn.atrous_conv2d(input_tensor, kernel, rate=dilation, output = tf.nn.atrous_conv2d(input_tensor, kernel, rate=dilation,
padding='SAME') padding='SAME')
return tf.transpose(output, [3, 1, 2, 0]) # output: [B, C, H, W, 8]
output = unpack_first2_dims(output, batch_size, channels)
# return: [8, B, H, W, C]
return tf.transpose(output, [4, 0, 2, 3, 1])
def gaussian_pixel_similarity(a, b, theta):
@@ -339,12 +384,12 @@ def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
[1]: https://arxiv.org/abs/2012.02310
Args:
feature_map: A float tensor of shape [height, width, channels] feature_map: A float tensor of shape [batch_size, height, width, channels]
dilation: int, the dilation factor. dilation: int, the dilation factor.
theta: The denominator while taking difference inside the gaussian. theta: The denominator while taking difference inside the gaussian.
Returns: Returns:
dilated_similarity: A tensor of shape [8, height, width] dilated_similarity: A tensor of shape [8, batch_size, height, width]
""" """
neighbors = generate_2d_neighbors(feature_map, dilation) neighbors = generate_2d_neighbors(feature_map, dilation)
feature_map = feature_map[tf.newaxis] feature_map = feature_map[tf.newaxis]
@@ -358,21 +403,26 @@ def dilated_cross_same_mask_label(instance_masks, dilation=2):
[1]: https://arxiv.org/abs/2012.02310
Args:
instance_masks: A float tensor of shape [num_instances, height, width] instance_masks: A float tensor of shape [batch_size, num_instances,
height, width]
dilation: int, the dilation factor. dilation: int, the dilation factor.
Returns: Returns:
dilated_same_label: A tensor of shape [8, num_instances, height, width] dilated_same_label: A tensor of shape [8, batch_size, num_instances,
height, width]
""" """
instance_masks = tf.transpose(instance_masks, (1, 2, 0)) # instance_masks: [batch_size, height, width, num_instances]
instance_masks = tf.transpose(instance_masks, (0, 2, 3, 1))
# neighbors: [8, batch_size, height, width, num_instances]
neighbors = generate_2d_neighbors(instance_masks, dilation) neighbors = generate_2d_neighbors(instance_masks, dilation)
# instance_masks = [1, batch_size, height, width, num_instances]
instance_masks = instance_masks[tf.newaxis] instance_masks = instance_masks[tf.newaxis]
same_mask_prob = ((instance_masks * neighbors) + same_mask_prob = ((instance_masks * neighbors) +
((1 - instance_masks) * (1 - neighbors))) ((1 - instance_masks) * (1 - neighbors)))
return tf.transpose(same_mask_prob, (0, 3, 1, 2)) return tf.transpose(same_mask_prob, (0, 1, 4, 2, 3))
def _per_pixel_single_conv(input_tensor, params, channels):
@@ -722,6 +772,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
return tf.squeeze(out, axis=-1)
def _batch_gt_list(gt_list):
return tf.stack(gt_list, axis=0)
def deepmac_proto_to_params(deepmac_config):
"""Convert proto to named tuple."""
@@ -765,7 +819,8 @@ def deepmac_proto_to_params(deepmac_config):
color_consistency_warmup_steps=
deepmac_config.color_consistency_warmup_steps,
color_consistency_warmup_start=
- deepmac_config.color_consistency_warmup_start
+ deepmac_config.color_consistency_warmup_start,
use_only_last_stage=deepmac_config.use_only_last_stage
)
@@ -808,8 +863,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
- loss = self._deepmac_params.classification_loss
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
@@ -847,60 +900,62 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Get the input to the mask network, given bounding boxes.
Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in boxes: A [batch_size, num_instances, 4] float tensor containing bounding
normalized coordinates. boxes in normalized coordinates.
pixel_embedding: A [height, width, embedding_size] float tensor pixel_embedding: A [batch_size, height, width, embedding_size] float
containing spatial pixel embeddings. tensor containing spatial pixel embeddings.
Returns: Returns:
embedding: A [num_instances, mask_height, mask_width, embedding_size + 2] embedding: A [batch_size, num_instances, mask_height, mask_width,
float tensor containing the inputs to the mask network. For each embedding_size + 2] float tensor containing the inputs to the mask
bounding box, we concatenate the normalized box coordinates to the network. For each bounding box, we concatenate the normalized box
cropped pixel embeddings. If predict_full_resolution_masks is set, coordinates to the cropped pixel embeddings. If
mask_height and mask_width are the same as height and width of predict_full_resolution_masks is set, mask_height and mask_width are
pixel_embedding. If not, mask_height and mask_width are the same as the same as height and width of pixel_embedding. If not, mask_height
mask_size. and mask_width are the same as mask_size.
""" """
num_instances = tf.shape(boxes)[0] batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_size = self._deepmac_params.mask_size mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks: if self._deepmac_params.predict_full_resolution_masks:
num_instances = tf.shape(boxes)[0] num_instances = tf.shape(boxes)[1]
pixel_embedding = pixel_embedding[tf.newaxis, :, :, :] pixel_embedding = pixel_embedding[:, tf.newaxis, :, :, :]
pixel_embeddings_processed = tf.tile(pixel_embedding, pixel_embeddings_processed = tf.tile(pixel_embedding,
[num_instances, 1, 1, 1]) [1, num_instances, 1, 1, 1])
image_shape = tf.shape(pixel_embeddings_processed) image_shape = tf.shape(pixel_embeddings_processed)
image_height, image_width = image_shape[1], image_shape[2] image_height, image_width = image_shape[2], image_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height), y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height),
tf.linspace(0.0, 1.0, image_width), tf.linspace(0.0, 1.0, image_width),
indexing='ij') indexing='ij')
blist = box_list.BoxList(boxes) ycenter = (boxes[:, :, 0] + boxes[:, :, 2]) / 2.0
ycenter, xcenter, _, _ = blist.get_center_coordinates_and_sizes() xcenter = (boxes[:, :, 1] + boxes[:, :, 3]) / 2.0
y_grid = y_grid[tf.newaxis, :, :] y_grid = y_grid[tf.newaxis, tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, :, :] x_grid = x_grid[tf.newaxis, tf.newaxis, :, :]
y_grid -= ycenter[:, tf.newaxis, tf.newaxis] y_grid -= ycenter[:, :, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, tf.newaxis, tf.newaxis] x_grid -= xcenter[:, :, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=3) coords = tf.stack([y_grid, x_grid], axis=4)
else: else:
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False. # TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
pixel_embeddings_processed = crop_and_resize_feature_map( embeddings = spatial_transform_ops.matmul_crop_and_resize(
pixel_embedding, boxes, mask_size) pixel_embedding, boxes, [mask_size, mask_size])
pixel_embeddings_processed = embeddings
mask_shape = tf.shape(pixel_embeddings_processed) mask_shape = tf.shape(pixel_embeddings_processed)
mask_height, mask_width = mask_shape[1], mask_shape[2] mask_height, mask_width = mask_shape[2], mask_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height), y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height),
tf.linspace(-1.0, 1.0, mask_width), tf.linspace(-1.0, 1.0, mask_width),
indexing='ij') indexing='ij')
coords = tf.stack([y_grid, x_grid], axis=2) coords = tf.stack([y_grid, x_grid], axis=2)
coords = coords[tf.newaxis, :, :, :] coords = coords[tf.newaxis, tf.newaxis, :, :, :]
coords = tf.tile(coords, [num_instances, 1, 1, 1]) coords = tf.tile(coords, [batch_size, num_instances, 1, 1, 1])
if self._deepmac_params.use_xy: if self._deepmac_params.use_xy:
return tf.concat([coords, pixel_embeddings_processed], axis=3) return tf.concat([coords, pixel_embeddings_processed], axis=4)
else: else:
return pixel_embeddings_processed return pixel_embeddings_processed
@@ -908,43 +963,94 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Return the instance embeddings from bounding box centers.
Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The boxes: A [batch_size, num_instances, 4] float tensor holding bounding
coordinates are in normalized input space. boxes. The coordinates are in normalized input space.
instance_embedding: A [height, width, embedding_size] float tensor instance_embedding: A [batch_size, height, width, embedding_size] float
containing the instance embeddings. tensor containing the instance embeddings.
Returns: Returns:
instance_embeddings: A [num_instances, embedding_size] shaped float tensor instance_embeddings: A [batch_size, num_instances, embedding_size]
containing the center embedding for each instance. shaped float tensor containing the center embedding for each instance.
""" """
blist = box_list.BoxList(boxes)
output_height = tf.shape(instance_embedding)[0] output_height = tf.cast(tf.shape(instance_embedding)[1], tf.float32)
output_width = tf.shape(instance_embedding)[1] output_width = tf.cast(tf.shape(instance_embedding)[2], tf.float32)
ymin = boxes[:, :, 0]
blist_output = box_list_ops.to_absolute_coordinates( xmin = boxes[:, :, 1]
blist, output_height, output_width, check_range=False) ymax = boxes[:, :, 2]
(y_center_output, x_center_output, xmax = boxes[:, :, 3]
_, _) = blist_output.get_center_coordinates_and_sizes()
center_coords_output = tf.stack([y_center_output, x_center_output], axis=1) y_center_output = (ymin + ymax) * output_height / 2.0
x_center_output = (xmin + xmax) * output_width / 2.0
center_coords_output = tf.stack([y_center_output, x_center_output], axis=2)
center_coords_output_int = tf.cast(center_coords_output, tf.int32) center_coords_output_int = tf.cast(center_coords_output, tf.int32)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int,
batch_dims=1)
return center_latents return center_latents
def predict(self, preprocessed_inputs, other_inputs):
prediction_dict = super(DeepMACMetaArch, self).predict(
preprocessed_inputs, other_inputs)
mask_logits = self._predict_mask_logits_from_gt_boxes(prediction_dict)
prediction_dict[MASK_LOGITS_GT_BOXES] = mask_logits
return prediction_dict
def _predict_mask_logits_from_embeddings(
self, pixel_embedding, instance_embedding, boxes):
mask_input = self._get_mask_head_input(boxes, pixel_embedding)
mask_input, batch_size, num_instances = flatten_first2_dims(mask_input)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
instance_embeddings, _, _ = flatten_first2_dims(instance_embeddings)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_logits = unpack_first2_dims(
mask_logits, batch_size, num_instances)
return mask_logits
def _predict_mask_logits_from_gt_boxes(self, prediction_dict):
mask_logits_list = []
boxes = _batch_gt_list(self.groundtruth_lists(fields.BoxListFields.boxes))
instance_embedding_list = prediction_dict[INSTANCE_EMBEDDING]
pixel_embedding_list = prediction_dict[PIXEL_EMBEDDING]
if self._deepmac_params.use_only_last_stage:
instance_embedding_list = [instance_embedding_list[-1]]
pixel_embedding_list = [pixel_embedding_list[-1]]
for (instance_embedding, pixel_embedding) in zip(instance_embedding_list,
pixel_embedding_list):
mask_logits_list.append(
self._predict_mask_logits_from_embeddings(
pixel_embedding, instance_embedding, boxes))
return mask_logits_list
def _get_groundtruth_mask_output(self, boxes, masks):
"""Get the expected mask output for each box.
Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in boxes: A [batch_size, num_instances, 4] float tensor containing bounding
normalized coordinates. boxes in normalized coordinates.
masks: A [num_instances, height, width] float tensor containing binary masks: A [batch_size, num_instances, height, width] float tensor
ground truth masks. containing binary ground truth masks.
Returns: Returns:
masks: If predict_full_resolution_masks is set, masks are not resized masks: If predict_full_resolution_masks is set, masks are not resized
and the size of this tensor is [num_instances, input_height, input_width]. and the size of this tensor is [batch_size, num_instances,
Otherwise, returns a tensor of size [num_instances, mask_size, mask_size]. input_height, input_width]. Otherwise, returns a tensor of size
[batch_size, num_instances, mask_size, mask_size].
""" """
mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks:
return masks
@@ -957,9 +1063,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return cropped_masks
def _resize_logits_like_gt(self, logits, gt):
- height, width = tf.shape(gt)[1], tf.shape(gt)[2]
+ height, width = tf.shape(gt)[2], tf.shape(gt)[3]
return resize_instance_masks(logits, (height, width))
def _aggregate_classification_loss(self, loss, gt, pred, method): def _aggregate_classification_loss(self, loss, gt, pred, method):
...@@ -1016,54 +1120,59 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1016,54 +1120,59 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else: else:
raise ValueError('Unknown loss aggregation - {}'.format(method)) raise ValueError('Unknown loss aggregation - {}'.format(method))
def _compute_per_instance_mask_prediction_loss( def _compute_mask_prediction_loss(
self, boxes, mask_logits, mask_gt): self, boxes, mask_logits, mask_gt):
"""Compute the per-instance mask loss. """Compute the per-instance mask loss.
Args: Args:
boxes: A [num_instances, 4] float tensor of GT boxes. boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
mask_logits: A [num_instances, height, width] float tensor of predicted mask_logits: A [batch_size, num_instances, height, width] float tensor of
masks predicted masks
mask_gt: The groundtruth mask. mask_gt: The groundtruth mask of the same shape as mask_logits.
Returns: Returns:
loss: A [num_instances] shaped tensor with the loss for each instance. loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance.
""" """
num_instances = tf.shape(boxes)[0] batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt) mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
mask_logits = tf.reshape(mask_logits, [num_instances, -1, 1]) mask_logits = tf.reshape(mask_logits, [batch_size * num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [num_instances, -1, 1]) mask_gt = tf.reshape(mask_gt, [batch_size * num_instances, -1, 1])
loss = self._deepmac_params.classification_loss( loss = self._deepmac_params.classification_loss(
prediction_tensor=mask_logits, prediction_tensor=mask_logits,
target_tensor=mask_gt, target_tensor=mask_gt,
weights=tf.ones_like(mask_logits)) weights=tf.ones_like(mask_logits))
return self._aggregate_classification_loss( loss = self._aggregate_classification_loss(
loss, mask_gt, mask_logits, 'normalize_auto') loss, mask_gt, mask_logits, 'normalize_auto')
return tf.reshape(loss, [batch_size, num_instances])
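The reshaping above treats every instance of every image as one sample in a larger batch before handing it to the configured classification loss. A self-contained illustration of the same pattern, using plain sigmoid cross-entropy in place of self._deepmac_params.classification_loss:

import tensorflow as tf

batch_size, num_instances, height, width = 2, 3, 4, 4
mask_logits = tf.random.normal([batch_size, num_instances, height, width])
mask_gt = tf.cast(
    tf.random.uniform([batch_size, num_instances, height, width]) > 0.5,
    tf.float32)

# Fold batch and instance dims so every instance is an independent sample.
flat_logits = tf.reshape(mask_logits, [batch_size * num_instances, height * width])
flat_gt = tf.reshape(mask_gt, [batch_size * num_instances, height * width])

per_pixel = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=flat_gt, logits=flat_logits)
per_instance = tf.reduce_mean(per_pixel, axis=1)  # [batch_size * num_instances]
loss = tf.reshape(per_instance, [batch_size, num_instances])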
def _compute_per_instance_box_consistency_loss( def _compute_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits): self, boxes_gt, boxes_for_crop, mask_logits):
"""Compute the per-instance box consistency loss. """Compute the per-instance box consistency loss.
Args: Args:
boxes_gt: A [num_instances, 4] float tensor of GT boxes. boxes_gt: A [batch_size, num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [num_instances, 4] float tensor of augmented boxes, boxes_for_crop: A [batch_size, num_instances, 4] float tensor of
to be used when using crop-and-resize based mask head. augmented boxes, to be used when using crop-and-resize based mask head.
mask_logits: A [num_instances, height, width] float tensor of predicted mask_logits: A [batch_size, num_instances, height, width]
masks. float tensor of predicted masks.
Returns: Returns:
loss: A [num_instances] shaped tensor with the loss for each instance. loss: A [batch_size, num_instances] shaped tensor with the loss for
each instance in the batch.
""" """
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2] shape = tf.shape(mask_logits)
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis] batch_size, num_instances, height, width = (
mask_logits = mask_logits[:, :, :, tf.newaxis] shape[0], shape[1], shape[2], shape[3])
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, :, tf.newaxis]
if self._deepmac_params.predict_full_resolution_masks: if self._deepmac_params.predict_full_resolution_masks:
gt_crop = filled_boxes[:, :, :, 0] gt_crop = filled_boxes[:, :, :, :, 0]
pred_crop = mask_logits[:, :, :, 0] pred_crop = mask_logits[:, :, :, :, 0]
else: else:
gt_crop = crop_and_resize_instance_masks( gt_crop = crop_and_resize_instance_masks(
filled_boxes, boxes_for_crop, self._deepmac_params.mask_size) filled_boxes, boxes_for_crop, self._deepmac_params.mask_size)
...@@ -1071,7 +1180,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1071,7 +1180,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits, boxes_for_crop, self._deepmac_params.mask_size) mask_logits, boxes_for_crop, self._deepmac_params.mask_size)
loss = 0.0 loss = 0.0
for axis in [1, 2]: for axis in [2, 3]:
if self._deepmac_params.box_consistency_tightness: if self._deepmac_params.box_consistency_tightness:
pred_max_raw = tf.reduce_max(pred_crop, axis=axis) pred_max_raw = tf.reduce_max(pred_crop, axis=axis)
...@@ -1083,44 +1192,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1083,44 +1192,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else: else:
pred_max = tf.reduce_max(pred_crop, axis=axis) pred_max = tf.reduce_max(pred_crop, axis=axis)
pred_max = pred_max[:, :, tf.newaxis] pred_max = pred_max[:, :, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis] gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, :, tf.newaxis]
flat_pred, batch_size, num_instances = flatten_first2_dims(pred_max)
flat_gt, _, _ = flatten_first2_dims(gt_max)
# We use flat tensors while calling loss functions because we
# want the loss per-instance to later multiply with the per-instance
# weight. Flattening the first 2 dims allows us to represent each instance
# in each batch as though they were samples in a larger batch.
raw_loss = self._deepmac_params.classification_loss( raw_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max, prediction_tensor=flat_pred,
target_tensor=gt_max, target_tensor=flat_gt,
weights=tf.ones_like(pred_max)) weights=tf.ones_like(flat_pred))
loss += self._aggregate_classification_loss( agg_loss = self._aggregate_classification_loss(
raw_loss, gt_max, pred_max, raw_loss, flat_gt, flat_pred,
self._deepmac_params.box_consistency_loss_normalize) self._deepmac_params.box_consistency_loss_normalize)
loss += unpack_first2_dims(agg_loss, batch_size, num_instances)
return loss return loss
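The box consistency loss compares max-projections of the predicted mask against max-projections of the filled groundtruth box along the height and width axes. The sketch below uses a hypothetical fill_boxes that rasterizes normalized boxes (which may differ from the repository's helper) to show the projection step on a toy example:

import tensorflow as tf

def fill_boxes_sketch(boxes, height, width):
  # boxes: [batch_size, num_instances, 4] normalized [ymin, xmin, ymax, xmax].
  # Returns binary box masks of shape [batch_size, num_instances, height, width].
  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=2)
  ys = tf.linspace(0.0, 1.0, height)[tf.newaxis, tf.newaxis, :, tf.newaxis]
  xs = tf.linspace(0.0, 1.0, width)[tf.newaxis, tf.newaxis, tf.newaxis, :]
  inside_y = ((ys >= ymin[:, :, tf.newaxis, tf.newaxis]) &
              (ys <= ymax[:, :, tf.newaxis, tf.newaxis]))
  inside_x = ((xs >= xmin[:, :, tf.newaxis, tf.newaxis]) &
              (xs <= xmax[:, :, tf.newaxis, tf.newaxis]))
  return tf.cast(inside_y & inside_x, tf.float32)

boxes = tf.constant([[[0.25, 0.25, 0.75, 0.75]]])     # one image, one box
filled = fill_boxes_sketch(boxes, height=8, width=8)
col_profile = tf.reduce_max(filled, axis=2)  # collapse height: [1, 1, 8]
row_profile = tf.reduce_max(filled, axis=3)  # collapse width:  [1, 1, 8]
# A well-trained mask prediction, max-reduced the same way, should match
# these profiles inside the groundtruth box and stay near zero outside it.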
def _compute_per_instance_color_consistency_loss( def _compute_color_consistency_loss(
self, boxes, preprocessed_image, mask_logits): self, boxes, preprocessed_image, mask_logits):
"""Compute the per-instance color consistency loss. """Compute the per-instance color consistency loss.
Args: Args:
boxes: A [num_instances, 4] float tensor of GT boxes. boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
preprocessed_image: A [height, width, 3] float tensor containing the preprocessed_image: A [batch_size, height, width, 3]
preprocessed image. float tensor containing the preprocessed image.
mask_logits: A [num_instances, height, width] float tensor of predicted mask_logits: A [batch_size, num_instances, height, width] float tensor of
masks. predicted masks.
Returns: Returns:
loss: A [num_instances] shaped tensor with the loss for each instance. loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance for each sample in the batch.
""" """
if not self._deepmac_params.predict_full_resolution_masks: if not self._deepmac_params.predict_full_resolution_masks:
logging.info('Color consistency is not implemented with RoIAlign ' logging.info('Color consistency is not implemented with RoIAlign '
', i.e., fixed-size masks. Returning 0 loss.') ', i.e., fixed-size masks. Returning 0 loss.')
return tf.zeros(tf.shape(boxes)[0]) return tf.zeros(tf.shape(boxes)[:2])
dilation = self._deepmac_params.color_consistency_dilation dilation = self._deepmac_params.color_consistency_dilation
height, width = (tf.shape(preprocessed_image)[0], height, width = (tf.shape(preprocessed_image)[1],
tf.shape(preprocessed_image)[1]) tf.shape(preprocessed_image)[2])
color_similarity = dilated_cross_pixel_similarity( color_similarity = dilated_cross_pixel_similarity(
preprocessed_image, dilation=dilation, theta=2.0) preprocessed_image, dilation=dilation, theta=2.0)
mask_probs = tf.nn.sigmoid(mask_logits) mask_probs = tf.nn.sigmoid(mask_logits)
...@@ -1132,20 +1250,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1132,20 +1250,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
color_similarity_mask = ( color_similarity_mask = (
color_similarity > self._deepmac_params.color_consistency_threshold) color_similarity > self._deepmac_params.color_consistency_threshold)
color_similarity_mask = tf.cast( color_similarity_mask = tf.cast(
color_similarity_mask[:, tf.newaxis, :, :], tf.float32) color_similarity_mask[:, :, tf.newaxis, :, :], tf.float32)
per_pixel_loss = -(color_similarity_mask * per_pixel_loss = -(color_similarity_mask *
tf.math.log(same_mask_label_probability)) tf.math.log(same_mask_label_probability))
# TODO(vighneshb) explore if shrinking the box by 1px helps. # TODO(vighneshb) explore if shrinking the box by 1px helps.
box_mask = fill_boxes(boxes, height, width) box_mask = fill_boxes(boxes, height, width)
box_mask_expanded = box_mask[tf.newaxis, :, :, :] box_mask_expanded = box_mask[tf.newaxis]
per_pixel_loss = per_pixel_loss * box_mask_expanded per_pixel_loss = per_pixel_loss * box_mask_expanded
loss = tf.reduce_sum(per_pixel_loss, axis=[0, 2, 3]) loss = tf.reduce_sum(per_pixel_loss, axis=[0, 3, 4])
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2])) num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[2, 3]))
loss = loss / num_box_pixels loss = loss / num_box_pixels
if ((self._deepmac_params.color_consistency_warmup_steps > 0) and if ((self._deepmac_params.color_consistency_warmup_steps > 0) and
self._is_training): tf.keras.backend.learning_phase()):
training_step = tf.cast(self.training_step, tf.float32) training_step = tf.cast(self.training_step, tf.float32)
warmup_steps = tf.cast( warmup_steps = tf.cast(
self._deepmac_params.color_consistency_warmup_steps, tf.float32) self._deepmac_params.color_consistency_warmup_steps, tf.float32)
...@@ -1157,56 +1275,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1157,56 +1275,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return loss return loss
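dilated_cross_pixel_similarity is not part of this diff. A rough sketch of the underlying idea, comparing each pixel to its eight neighbours at the given dilation and mapping the colour distance through exp(-d / theta); note that tf.roll wraps around at image borders, which the real helper presumably handles differently:

import tensorflow as tf

def dilated_cross_pixel_similarity_sketch(image, dilation=2, theta=2.0):
  # image: [batch_size, height, width, 3] (e.g. LAB-converted pixels).
  # Returns [8, batch_size, height, width], one map per neighbour direction,
  # matching the indexing color_similarity_mask[:, :, tf.newaxis, :, :] above.
  shifts = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
            (0, 1), (1, -1), (1, 0), (1, 1)]
  similarities = []
  for dy, dx in shifts:
    neighbour = tf.roll(image, shift=[dy * dilation, dx * dilation], axis=[1, 2])
    distance = tf.norm(image - neighbour, axis=-1)       # [batch_size, h, w]
    similarities.append(tf.exp(-distance / theta))
  return tf.stack(similarities, axis=0)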
def _compute_per_instance_deepmac_losses( def _compute_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding, self, boxes, masks_logits, masks_gt, image):
image):
"""Returns the mask loss per instance. """Returns the mask loss per instance.
Args: Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The boxes: A [batch_size, num_instances, 4] float tensor holding bounding
coordinates are in normalized input space. boxes. The coordinates are in normalized input space.
masks: A [num_instances, input_height, input_width] float tensor masks_logits: A [batch_size, num_instances, input_height, input_width]
containing the instance masks. float tensor containing the instance mask predictions in their logit
instance_embedding: A [output_height, output_width, embedding_size] form.
float tensor containing the instance embeddings. masks_gt: A [batch_size, num_instances, input_height, input_width] float
pixel_embedding: optional [output_height, output_width, tensor containing the groundtruth masks.
pixel_embedding_size] float tensor containing the per-pixel embeddings. image: [batch_size, output_height, output_width, channels] float tensor
image: [output_height, output_width, channels] float tensor
denoting the input image. denoting the input image.
Returns: Returns:
mask_prediction_loss: A [num_instances] shaped float tensor containing the mask_prediction_loss: A [batch_size, num_instances] shaped float tensor
mask loss for each instance. containing the mask loss for each instance in the batch.
box_consistency_loss: A [num_instances] shaped float tensor containing box_consistency_loss: A [batch_size, num_instances] shaped float tensor
the box consistency loss for each instance. containing the box consistency loss for each instance in the batch.
color_consistency_loss: A [num_instances] shaped float tensor containing color_consistency_loss: A [batch_size, num_instances] shaped float tensor
the color consistency loss. containing the color consistency loss in the batch.
""" """
if tf.keras.backend.learning_phase(): if tf.keras.backend.learning_phase():
boxes_for_crop = preprocessor.random_jitter_boxes( boxes = tf.stop_gradient(boxes)
boxes, self._deepmac_params.max_roi_jitter_ratio, def jitter_func(boxes):
jitter_mode=self._deepmac_params.roi_jitter_mode) return preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
boxes_for_crop = tf.map_fn(jitter_func,
boxes, parallel_iterations=128)
else: else:
boxes_for_crop = boxes boxes_for_crop = boxes
mask_input = self._get_mask_head_input( mask_gt = self._get_groundtruth_mask_output(
boxes_for_crop, pixel_embedding) boxes_for_crop, masks_gt)
instance_embeddings = self._get_instance_embeddings(
boxes_for_crop, instance_embedding)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_gt = self._get_groundtruth_mask_output(boxes_for_crop, masks)
mask_prediction_loss = self._compute_per_instance_mask_prediction_loss( mask_prediction_loss = self._compute_mask_prediction_loss(
boxes_for_crop, mask_logits, mask_gt) boxes_for_crop, masks_logits, mask_gt)
box_consistency_loss = self._compute_per_instance_box_consistency_loss( box_consistency_loss = self._compute_box_consistency_loss(
boxes, boxes_for_crop, mask_logits) boxes, boxes_for_crop, masks_logits)
color_consistency_loss = self._compute_per_instance_color_consistency_loss( color_consistency_loss = self._compute_color_consistency_loss(
boxes, image, mask_logits) boxes, image, masks_logits)
return { return {
DEEP_MASK_ESTIMATION: mask_prediction_loss, DEEP_MASK_ESTIMATION: mask_prediction_loss,
...@@ -1224,7 +1339,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1224,7 +1339,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
' consistency loss is not supported in TF1.')) ' consistency loss is not supported in TF1.'))
return tfio.experimental.color.rgb_to_lab(raw_image) return tfio.experimental.color.rgb_to_lab(raw_image)
def _compute_instance_masks_loss(self, prediction_dict): def _compute_masks_loss(self, prediction_dict):
"""Computes the mask loss. """Computes the mask loss.
Args: Args:
...@@ -1236,10 +1351,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1236,10 +1351,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
Returns: Returns:
loss_dict: A dict mapping string (loss names) to scalar floats. loss_dict: A dict mapping string (loss names) to scalar floats.
""" """
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
allowed_masked_classes_ids = ( allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids) self._deepmac_params.allowed_masked_classes_ids)
...@@ -1248,8 +1359,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1248,8 +1359,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for loss_name in MASK_LOSSES: for loss_name in MASK_LOSSES:
loss_dict[loss_name] = 0.0 loss_dict[loss_name] = 0.0
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0]) prediction_shape = tf.shape(prediction_dict[MASK_LOGITS_GT_BOXES][0])
height, width = prediction_shape[1], prediction_shape[2] height, width = prediction_shape[2], prediction_shape[3]
preprocessed_image = tf.image.resize( preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width)) prediction_dict['preprocessed_inputs'], (height, width))
...@@ -1258,42 +1369,46 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1258,42 +1369,46 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
# TODO(vighneshb) See if we can save memory by only using the final # TODO(vighneshb) See if we can save memory by only using the final
# prediction # prediction
# Iterate over multiple predictions by backbone (for hourglass length=2) # Iterate over multiple predictions by backbone (for hourglass length=2)
for instance_pred, pixel_pred in zip(
prediction_dict[INSTANCE_EMBEDDING], gt_boxes = _batch_gt_list(
prediction_dict[PIXEL_EMBEDDING]): self.groundtruth_lists(fields.BoxListFields.boxes))
# Iterate over samples in batch gt_weights = _batch_gt_list(
# TODO(vighneshb) find out how autograph is handling this. Converting self.groundtruth_lists(fields.BoxListFields.weights))
# to a single op may give speed improvements gt_masks = _batch_gt_list(
for i, (boxes, weights, classes, masks) in enumerate( self.groundtruth_lists(fields.BoxListFields.masks))
zip(gt_boxes_list, gt_weights_list, gt_classes_list, gt_masks_list)): gt_classes = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.classes))
# TODO(vighneshb) Add sub-sampling back if required.
classes, valid_mask_weights, masks = filter_masked_classes( mask_logits_list = prediction_dict[MASK_LOGITS_GT_BOXES]
allowed_masked_classes_ids, classes, weights, masks) for mask_logits in mask_logits_list:
sample_loss_dict = self._compute_per_instance_deepmac_losses( # TODO(vighneshb) Add sub-sampling back if required.
boxes, masks, instance_pred[i], pixel_pred[i], image[i]) _, valid_mask_weights, gt_masks = filter_masked_classes(
allowed_masked_classes_ids, gt_classes,
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights gt_weights, gt_masks)
for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= weights sample_loss_dict = self._compute_deepmac_losses(
gt_boxes, mask_logits, gt_masks, image)
num_instances = tf.maximum(tf.reduce_sum(weights), 1.0)
num_instances_allowed = tf.maximum( sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
tf.reduce_sum(valid_mask_weights), 1.0) for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= gt_weights
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) / num_instances = tf.maximum(tf.reduce_sum(gt_weights), 1.0)
num_instances_allowed) num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) / loss_dict[DEEP_MASK_ESTIMATION] += (
num_instances) tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) /
num_instances_allowed)
batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING]) for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) /
return dict((key, loss / float(batch_size * num_predictions)) num_instances)
num_predictions = len(mask_logits_list)
return dict((key, loss / float(num_predictions))
for key, loss in loss_dict.items()) for key, loss in loss_dict.items())
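A toy numeric check of the weighting and normalization above (hypothetical values): per-instance losses [0.4, 0.8, 0.2] with weights [1, 1, 0] (the last instance is padding) give sum(loss * weights) / max(sum(weights), 1) = 1.2 / 2 = 0.6 per backbone prediction, and the returned value is that amount averaged over the number of predictions:

import tensorflow as tf

per_instance_loss = tf.constant([[0.4, 0.8, 0.2]])  # [batch_size=1, num_instances=3]
weights = tf.constant([[1.0, 1.0, 0.0]])            # third instance is padding
num_predictions = 2                                 # e.g. an hourglass backbone

weighted = per_instance_loss * weights
loss = 0.0
for _ in range(num_predictions):
  loss += tf.reduce_sum(weighted) / tf.maximum(tf.reduce_sum(weights), 1.0)
loss = loss / float(num_predictions)                # 1.2 / 2 = 0.6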
def loss(self, prediction_dict, true_image_shapes, scope=None): def loss(self, prediction_dict, true_image_shapes, scope=None):
...@@ -1302,7 +1417,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1302,7 +1417,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict, true_image_shapes, scope) prediction_dict, true_image_shapes, scope)
if self._deepmac_params is not None: if self._deepmac_params is not None:
mask_loss_dict = self._compute_instance_masks_loss( mask_loss_dict = self._compute_masks_loss(
prediction_dict=prediction_dict) prediction_dict=prediction_dict)
for loss_name in MASK_LOSSES: for loss_name in MASK_LOSSES:
...@@ -1363,50 +1478,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1363,50 +1478,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_size] containing binary per-box instance masks. mask_size] containing binary per-box instance masks.
""" """
def process(elems): height, width = (tf.shape(instance_embedding)[1],
boxes, instance_embedding, pixel_embedding = elems tf.shape(instance_embedding)[2])
return self._postprocess_sample(boxes, instance_embedding,
pixel_embedding)
max_instances = self._center_params.max_box_predictions
return tf.map_fn(process, [boxes_output_stride, instance_embedding,
pixel_embedding],
dtype=tf.float32, parallel_iterations=max_instances)
def _postprocess_sample(self, boxes_output_stride,
instance_embedding, pixel_embedding):
"""Post process masks for a single sample.
Args:
boxes_output_stride: A [num_instances, 4] float tensor containing
bounding boxes in the absolute output space.
instance_embedding: A [output_height, output_width, embedding_size]
float tensor containing instance embeddings.
pixel_embedding: A [batch_size, output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embedding.
Returns:
masks: A float tensor of size [num_instances, mask_height, mask_width]
containing binary per-box instance masks. If
predict_full_resolution_masks is set, the masks will be resized to
postprocess_crop_size. Otherwise, mask_height=mask_width=mask_size
"""
height, width = (tf.shape(instance_embedding)[0],
tf.shape(instance_embedding)[1])
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
blist = box_list.BoxList(boxes_output_stride) ymin, xmin, ymax, xmax = tf.unstack(boxes_output_stride, axis=2)
blist = box_list_ops.to_normalized_coordinates( ymin /= height
blist, height, width, check_range=False) ymax /= height
boxes = blist.get() xmin /= width
xmax /= width
mask_input = self._get_mask_head_input(boxes, pixel_embedding) boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
mask_logits = self._mask_net( mask_logits = self._predict_mask_logits_from_embeddings(
instance_embeddings, mask_input, pixel_embedding, instance_embedding, boxes)
training=tf.keras.backend.learning_phase())
# TODO(vighneshb) Explore sweeping mask thresholds. # TODO(vighneshb) Explore sweeping mask thresholds.
...@@ -1416,7 +1499,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1416,7 +1499,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
height *= self._stride height *= self._stride
width *= self._stride width *= self._stride
mask_logits = resize_instance_masks(mask_logits, (height, width)) mask_logits = resize_instance_masks(mask_logits, (height, width))
mask_logits = crop_masks_within_boxes(
mask_logits = crop_and_resize_instance_masks(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size) mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
masks_prob = tf.nn.sigmoid(mask_logits) masks_prob = tf.nn.sigmoid(mask_logits)
......