Commit d305396d authored by A. Unique TensorFlower

ShapeMask Heads and Losses.

PiperOrigin-RevId: 312624281
parent eb49ae73
......@@ -29,6 +29,11 @@ def evaluator_generator(params):
elif params.type == 'box_and_mask':
evaluator = coco_evaluator.COCOEvaluator(
annotation_file=params.val_json_file, include_mask=True)
elif params.type == 'shapemask_box_and_mask':
evaluator = coco_evaluator.ShapeMaskCOCOEvaluator(
mask_eval_class=params.mask_eval_class,
annotation_file=params.val_json_file, include_mask=True)
else:
raise ValueError('Evaluator %s is not supported.' % params.type)
......
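For context, the new branch keys purely off `params.type`. A minimal sketch of driving this dispatch follows; the attribute-style config is a hypothetical stand-in for the project's real params object, and the field values are illustrative only:
import types
eval_params = types.SimpleNamespace(
    type='shapemask_box_and_mask',
    mask_eval_class='all',  # illustrative value; mirrors the field read above
    val_json_file='/path/to/instances_val2017.json')
evaluator = evaluator_generator(eval_params)  # -> ShapeMaskCOCOEvaluator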
......@@ -85,8 +85,8 @@ def retinanet_head_generator(params):
def rpn_head_generator(params):
"""Generator function for RPN head architecture."""
head_params = params.rpn_head
return heads.RpnHead(
params.architecture.min_level,
params.architecture.max_level,
......@@ -126,3 +126,38 @@ def mask_rcnn_head_generator(params):
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
def shapeprior_head_generator(params):
"""Generator function for shape prior head architecture."""
head_params = params.shapemask_head
return heads.ShapemaskPriorHead(
params.architecture.num_classes,
head_params.num_downsample_channels,
head_params.mask_crop_size,
head_params.use_category_for_mask,
head_params.shape_prior_path)
def coarsemask_head_generator(params):
"""Generator function for ShapeMask coarse mask head architecture."""
head_params = params.shapemask_head
return heads.ShapemaskCoarsemaskHead(
params.architecture.num_classes,
head_params.num_downsample_channels,
head_params.mask_crop_size,
head_params.use_category_for_mask,
head_params.num_convs,
norm_activation=norm_activation_generator(params.norm_activation))
def finemask_head_generator(params):
"""Generator function for Shapemask fine mask head architecture."""
head_params = params.shapemask_head
return heads.ShapemaskFinemaskHead(
params.architecture.num_classes,
head_params.num_downsample_channels,
head_params.mask_crop_size,
head_params.use_category_for_mask,
head_params.num_convs,
head_params.upsample_factor)
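All three new generators read the same `params.shapemask_head` block plus `params.architecture.num_classes`. A sketch of the fields they consume, under the assumption that a plain attribute-style config is acceptable (values are illustrative, not the repo's defaults):
import types
params = types.SimpleNamespace(
    architecture=types.SimpleNamespace(num_classes=91),
    norm_activation=types.SimpleNamespace(),  # would also need the usual momentum/epsilon fields
    shapemask_head=types.SimpleNamespace(
        num_downsample_channels=128,
        mask_crop_size=32,
        use_category_for_mask=True,
        shape_prior_path='/path/to/class_priors.npy',
        num_convs=4,
        upsample_factor=2))
# prior_head = shapeprior_head_generator(params)
# coarse_head = coarsemask_head_generator(params)
# fine_head = finemask_head_generator(params)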
......@@ -19,9 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend
......@@ -56,6 +54,7 @@ class RpnHead(tf.keras.layers.Layer):
intermediate conv layers.
use_separable_conv: `bool`, indicating whether the separable conv layers
is used.
activation: activation function. Supports 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
......@@ -165,6 +164,7 @@ class FastrcnnHead(tf.keras.layers.Layer):
predictions.
fc_dims: `int` number that represents the number of dimension of the FC
layers.
activation: activation function. Supports 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
......@@ -296,6 +296,7 @@ class MaskrcnnHead(tf.keras.layers.Layer):
intermediate conv layers.
use_separable_conv: `bool`, indicating whether the separable conv layers
is used.
activation: activation function. Supports 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
......@@ -566,8 +567,8 @@ class RetinanetHead(object):
with self._class_name_scope:
for i in range(self._num_convs):
features = self._class_conv[i](features)
# The convolution layers in the class net are shared among all levels,
# but each level has its own batch normalization to capture the statistical
# difference among different levels.
name = self._class_net_batch_norm_name(i, level)
features = self._class_norm_activation[name](
......@@ -601,12 +602,7 @@ class ShapemaskPriorHead(object):
num_downsample_channels,
mask_crop_size,
use_category_for_mask,
shape_prior_path):
"""Initialize params to build RetinaNet head.
Args:
......@@ -614,30 +610,18 @@ class ShapemaskPriorHead(object):
num_downsample_channels: number of channels in mask branch.
mask_crop_size: feature crop size.
use_category_for_mask: use class information in mask branch.
shape_prior_path: the path to load shape priors.
"""
self._mask_num_classes = num_classes if use_category_for_mask else 1
self._num_downsample_channels = num_downsample_channels
self._mask_crop_size = mask_crop_size
self._shape_prior_path = shape_prior_path
self._use_category_for_mask = use_category_for_mask
self._shape_prior_fc = tf.keras.layers.Dense(
self._num_downsample_channels, name='shape-prior-fc')
def __call__(self, fpn_features, boxes, outer_boxes, classes, is_training):
"""Generate the detection priors from the box detections and FPN features.
This corresponds to Fig. 4 of the ShapeMask paper at
https://arxiv.org/pdf/1904.03239.pdf
......@@ -654,221 +638,96 @@ class ShapemaskPriorHead(object):
is_training: training mode or not.
Returns:
instance_features: a float Tensor of shape [batch_size * num_instances,
mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
instance feature crop.
detection_priors: A float Tensor of shape [batch_size * num_instances,
mask_size, mask_size, 1].
"""
with backend.get_graph().as_default(), tf.name_scope('prior_mask'):
batch_size, num_instances, _ = boxes.get_shape().as_list()
outer_boxes = tf.cast(outer_boxes, tf.float32)
boxes = tf.cast(boxes, tf.float32)
instance_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, outer_boxes, output_size=self._mask_crop_size)
instance_features = self._shape_prior_fc(instance_features)
shape_priors = self._get_priors()
# Get uniform priors for each outer box.
uniform_priors = tf.ones([batch_size, num_instances, self._mask_crop_size,
self._mask_crop_size])
uniform_priors = spatial_transform_ops.crop_mask_in_target_box(
uniform_priors, boxes, outer_boxes, self._mask_crop_size)
# Classify shape priors using uniform priors + instance features.
prior_distribution = self._classify_shape_priors(
tf.cast(instance_features, tf.float32), uniform_priors, classes)
instance_priors = tf.gather(shape_priors, classes)
instance_priors *= tf.expand_dims(tf.expand_dims(
tf.cast(prior_distribution, tf.float32), axis=-1), axis=-1)
instance_priors = tf.reduce_sum(instance_priors, axis=2)
detection_priors = spatial_transform_ops.crop_mask_in_target_box(
instance_priors, boxes, outer_boxes, self._mask_crop_size)
return instance_features, detection_priors
def _get_priors(self):
"""Load shape priors from file."""
# Loads class-specific or class-agnostic shape priors.
if self._shape_prior_path:
# Priors are loaded into shape [mask_num_classes, num_clusters, 32, 32].
priors = np.load(tf.io.gfile.GFile(self._shape_prior_path, 'rb'))
priors = tf.convert_to_tensor(priors, dtype=tf.float32)
self._num_clusters = priors.get_shape().as_list()[1]
else:
# If the prior path does not exist, do not use priors, i.e., priors
# equal a uniform empty 32x32 patch.
self._num_clusters = 1
priors = tf.zeros([self._mask_num_classes, self._num_clusters,
self._mask_crop_size, self._mask_crop_size])
return priors
def _classify_shape_priors(self, features, uniform_priors, classes):
"""Classify the uniform prior by predicting the shape modes.
Classify the object crop features into K modes of the clusters for each
category.
Args:
features: A float Tensor of shape [batch_size, num_instances,
mask_size, mask_size, num_channels].
uniform_priors: A float Tensor of shape [batch_size, num_instances,
mask_size, mask_size] representing the uniform detection priors.
classes: An int Tensor of shape [batch_size, num_instances]
of detection class ids.
Returns:
prior_distribution: A float Tensor of shape
[batch_size, num_instances, num_clusters] representing the classifier
output probability over all possible shapes.
"""
batch_size, num_instances, _, _, _ = features.get_shape().as_list()
features *= tf.expand_dims(uniform_priors, axis=-1)
# Reduce spatial dimension of features. The features have shape
# [batch_size, num_instances, num_channels].
features = tf.reduce_mean(features, axis=(2, 3))
logits = tf.keras.layers.Dense(
self._mask_num_classes * self._num_clusters,
kernel_initializer=tf.random_normal_initializer(stddev=0.01))(features)
logits = tf.reshape(logits,
[batch_size, num_instances,
self._mask_num_classes, self._num_clusters])
if self._use_category_for_mask:
logits = tf.gather(logits, tf.expand_dims(classes, axis=-1), batch_dims=2)
logits = tf.squeeze(logits, axis=2)
else:
logits = logits[:, :, 0, :]
distribution = tf.nn.softmax(logits, name='shape_prior_weights')
return distribution
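The fused detection prior built in `__call__` above is a convex combination of the K cluster priors for each instance's class, weighted by the softmax that `_classify_shape_priors` returns. An equivalent NumPy sketch with toy sizes (shapes only; the real code then crops the result into the outer boxes):
import numpy as np
batch, n_inst, n_cls, k, s = 2, 3, 5, 4, 32
shape_priors = np.random.rand(n_cls, k, s, s).astype(np.float32)
classes = np.random.randint(0, n_cls, size=(batch, n_inst))
dist = np.random.dirichlet(np.ones(k), size=(batch, n_inst))  # stand-in for prior_distribution
instance_priors = shape_priors[classes]                        # tf.gather -> [batch, n_inst, k, s, s]
fused = (instance_priors * dist[..., None, None]).sum(axis=2)  # sum over clusters -> [batch, n_inst, s, s]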
class ShapemaskCoarsemaskHead(object):
......@@ -879,7 +738,8 @@ class ShapemaskCoarsemaskHead(object):
num_downsample_channels,
mask_crop_size,
use_category_for_mask,
num_convs,
norm_activation=nn_ops.norm_activation_builder()):
"""Initialize params to build ShapeMask coarse and fine prediction head.
Args:
......@@ -889,118 +749,106 @@ class ShapemaskCoarsemaskHead(object):
use_category_for_mask: use class information in mask branch.
num_convs: `int` number of stacked convolution before the last prediction
layer.
norm_activation: an operation that includes a normalization layer
followed by an optional activation layer.
"""
self._mask_num_classes = num_classes if use_category_for_mask else 1
self._use_category_for_mask = use_category_for_mask
self._num_downsample_channels = num_downsample_channels
self._mask_crop_size = mask_crop_size
self._num_convs = num_convs
self._norm_activation = norm_activation
self._coarse_mask_fc = tf.keras.layers.Dense(
self._num_downsample_channels, name='coarse-mask-fc')
self._class_conv = []
self._class_norm_activation = []
for i in range(self._num_convs):
self._class_conv.append(tf.keras.layers.Conv2D(
self._num_downsample_channels,
kernel_size=(3, 3),
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
padding='same',
name='coarse-mask-class-%d' % i))
self._class_norm_activation.append(
norm_activation(name='coarse-mask-class-%d-bn' % i))
self._class_predict = tf.keras.layers.Conv2D(
self._mask_num_classes,
kernel_size=(1, 1),
# Focal loss bias initialization to have foreground 0.01 probability.
bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
padding='same',
name='coarse-mask-class-predict')
def __call__(self, features, detection_priors, classes, is_training):
"""Generate instance masks from FPN features and detection priors.
This corresponds to Figs. 5-6 of the ShapeMask paper at
https://arxiv.org/pdf/1904.03239.pdf
Args:
features: a float Tensor of shape [batch_size, num_instances,
mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
instance feature crop.
detection_priors: a float Tensor of shape [batch_size, num_instances,
mask_crop_size, mask_crop_size, 1]. This is the detection prior for
the instance.
classes: an int Tensor of shape [batch_size, num_instances]
of instance classes.
is_training: a bool indicating whether in training mode.
Returns:
mask_outputs: instance mask prediction as a float Tensor of shape
[batch_size, num_instances, mask_size, mask_size].
"""
with backend.get_graph().as_default(), tf.name_scope('coarse_mask'):
# Transform detection priors to have the same dimension as features.
detection_priors = tf.expand_dims(detection_priors, axis=-1)
detection_priors = self._coarse_mask_fc(detection_priors)
features += detection_priors
mask_logits = self.decoder_net(features, is_training)
# Gather the logits with right input class.
if self._use_category_for_mask:
mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
mask_logits = tf.gather(mask_logits, tf.expand_dims(classes, -1),
batch_dims=2)
mask_logits = tf.squeeze(mask_logits, axis=2)
else:
mask_logits = mask_logits[..., 0]
return mask_logits
def decoder_net(self, features, is_training=False):
"""Coarse mask decoder network architecture.
Args:
features: A tensor of size [batch, height_in, width_in, channels_in].
is_training: Whether batch_norm layers are in training mode.
Returns:
features: A feature tensor of size [batch, output_size, output_size,
num_channels]
"""
(batch_size, num_instances, height, width,
num_channels) = features.get_shape().as_list()
features = tf.reshape(features, [batch_size * num_instances, height, width,
num_channels])
for i in range(self._num_convs):
features = self._class_conv[i](features)
features = self._class_norm_activation[i](features,
is_training=is_training)
mask_logits = self._class_predict(features)
mask_logits = tf.reshape(mask_logits, [batch_size, num_instances, height,
width, self._mask_num_classes])
return mask_logits
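The transpose/gather/squeeze idiom used in `__call__` above picks, for every instance, the logit plane of that instance's own class. The same indexing in isolation, as a TF2 eager sketch with toy shapes:
import tensorflow as tf
batch, n_inst, size, n_cls = 2, 3, 8, 5
logits5d = tf.random.normal([batch, n_inst, size, size, n_cls])
classes = tf.random.uniform([batch, n_inst], maxval=n_cls, dtype=tf.int32)
x = tf.transpose(logits5d, [0, 1, 4, 2, 3])            # [batch, n_inst, n_cls, size, size]
x = tf.gather(x, tf.expand_dims(classes, -1), batch_dims=2)
x = tf.squeeze(x, axis=2)                              # [batch, n_inst, size, size]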
class ShapemaskFinemaskHead(object):
......@@ -1010,9 +858,9 @@ class ShapemaskFinemaskHead(object):
num_classes,
num_downsample_channels,
mask_crop_size,
use_category_for_mask,
num_convs,
upsample_factor,
norm_activation=nn_ops.norm_activation_builder()):
"""Initialize params to build ShapeMask coarse and fine prediction head.
......@@ -1020,33 +868,29 @@ class ShapemaskFinemaskHead(object):
num_classes: `int` number of mask classification categories.
num_downsample_channels: `int` number of filters at mask head.
mask_crop_size: feature crop size.
use_category_for_mask: use class information in mask branch.
num_convs: `int` number of stacked convolution before the last prediction
layer.
upsample_factor: `int` number of fine mask upsampling factor.
norm_activation: an operation that includes a batch normalization layer
followed by a relu layer (optional).
"""
self._use_category_for_mask = use_category_for_mask
self._mask_num_classes = num_classes if use_category_for_mask else 1
self._num_downsample_channels = num_downsample_channels
self._mask_crop_size = mask_crop_size
self._num_convs = num_convs
self.up_sample_factor = upsample_factor
self._fine_mask_fc = tf.keras.layers.Dense(
self._num_downsample_channels, name='fine-mask-fc')
self._upsample_conv = tf.keras.layers.Conv2DTranspose(
self._num_downsample_channels,
(self.up_sample_factor, self.up_sample_factor),
(self.up_sample_factor, self.up_sample_factor),
name='fine-mask-conv2d-tran')
self._fine_class_conv = []
self._fine_class_bn = []
for i in range(self._num_convs):
......@@ -1059,60 +903,73 @@ class ShapemaskFinemaskHead(object):
stddev=0.01),
activation=None,
padding='same',
name='fine-mask-class-%d' % i))
self._fine_class_bn.append(norm_activation(
name='fine-mask-class-%d-bn' % i))
self._class_predict_conv = tf.keras.layers.Conv2D(
self._mask_num_classes,
kernel_size=(1, 1),
# Focal loss bias initialization to have foreground 0.01 probability.
bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
padding='same',
name='fine-mask-class-predict')
def __call__(self, features, mask_logits, classes, is_training):
"""Generate instance masks from FPN features and detection priors.
This corresponds to Figs. 5-6 of the ShapeMask paper at
https://arxiv.org/pdf/1904.03239.pdf
Args:
features: a float Tensor of shape
[batch_size, num_instances, mask_crop_size, mask_crop_size,
num_downsample_channels]. This is the instance feature crop.
mask_logits: a float Tensor of shape
[batch_size, num_instances, mask_crop_size, mask_crop_size] indicating
predicted mask logits.
classes: an int Tensor of shape [batch_size, num_instances]
of instance classes.
is_training: a bool indicating whether in training mode.
Returns:
mask_outputs: instance mask prediction as a float Tensor of shape
[batch_size, num_instances, mask_size, mask_size].
"""
# Extract the foreground mean features
# with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE):
with backend.get_graph().as_default(), tf.name_scope('fine_mask'):
mask_probs = tf.nn.sigmoid(mask_logits)
# Compute instance embedding for hard average.
binary_mask = tf.cast(tf.greater(mask_probs, 0.5), features.dtype)
instance_embedding = tf.reduce_sum(
features * tf.expand_dims(binary_mask, axis=-1), axis=(2, 3))
instance_embedding /= tf.expand_dims(
tf.reduce_sum(binary_mask, axis=(2, 3)) + 1e-20, axis=-1)
# Take the difference between crop features and mean instance features.
features -= tf.expand_dims(
tf.expand_dims(instance_embedding, axis=2), axis=2)
features += self._fine_mask_fc(tf.expand_dims(mask_probs, axis=-1))
# Decoder to generate upsampled segmentation mask.
mask_logits = self.decoder_net(features, is_training)
if self._use_category_for_mask:
mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
mask_logits = tf.gather(mask_logits,
tf.expand_dims(classes, -1), batch_dims=2)
mask_logits = tf.squeeze(mask_logits, axis=2)
else:
mask_logits = mask_logits[..., 0]
return mask_logits
def decoder_net(self, features, is_training=False):
"""Fine mask decoder network architecture.
Args:
features: A tensor of size [batch, height_in, width_in, channels_in].
is_training: Whether batch_norm layers are in training mode.
Returns:
......@@ -1120,11 +977,23 @@ class ShapemaskFinemaskHead(object):
num_channels], where output size is self.up_sample_factor times
that of input.
"""
(batch_size, num_instances, height, width,
num_channels) = features.get_shape().as_list()
features = tf.reshape(features, [batch_size * num_instances, height, width,
num_channels])
for i in range(self._num_convs):
features = self._fine_class_conv[i](features)
features = self._fine_class_bn[i](features, is_training=is_training)
if self.up_sample_factor > 1:
features = self._upsample_conv(features)
# Predict per-class instance masks.
mask_logits = self._class_predict_conv(features)
mask_logits = tf.reshape(mask_logits,
[batch_size, num_instances,
height * self.up_sample_factor,
width * self.up_sample_factor,
self._mask_num_classes])
return mask_logits
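In the fine head, `Conv2DTranspose` is built with kernel size equal to stride, so upsampling is an exact, non-overlapping factor of `up_sample_factor` in each spatial dimension. A quick shape check (TF2 eager, toy dimensions):
import tensorflow as tf
factor, channels = 2, 128
upsample = tf.keras.layers.Conv2DTranspose(
    channels, (factor, factor), (factor, factor))  # kernel == stride == factor
x = tf.random.normal([6, 32, 32, channels])        # [batch * num_instances, h, w, c]
print(upsample(x).shape)                           # (6, 64, 64, 128)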
......@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
......@@ -105,5 +104,5 @@ def norm_activation_builder(momentum=0.997,
momentum=momentum,
epsilon=epsilon,
trainable=trainable,
activation=activation,
**kwargs)
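With `activation` forwarded instead of hard-coded to 'relu', callers can now request swish-activated norm layers. A hedged usage sketch, following the call pattern the heads in this change use (builder, then named layer, then call with `is_training`):
norm_activation = norm_activation_builder(activation='swish')
bn_act = norm_activation(name='coarse-mask-class-0-bn')
# features = bn_act(features, is_training=True)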
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Losses used for Mask-RCNN."""
"""Losses used for detection models."""
from __future__ import absolute_import
from __future__ import division
......@@ -479,12 +479,62 @@ class RetinanetBoxLoss(object):
class ShapemaskMseLoss(object):
"""ShapeMask mask Mean Squared Error loss function wrapper."""
def __call__(self, probs, labels, valid_mask):
"""Compute instance segmentation loss.
Args:
probs: A float Tensor of shape [batch_size, num_instances, mask_size,
mask_size] of predicted mask probabilities.
labels: A float32/float16 Tensor of shape [batch_size, num_instances,
mask_size, mask_size], where mask_size =
mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
for coarse masks and shape priors.
valid_mask: a binary mask indicating valid training masks.
Returns:
loss: a float tensor representing total mask prediction loss.
"""
with tf.name_scope('shapemask_prior_loss'):
batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
diff = (tf.cast(labels, dtype=tf.float32) -
tf.cast(probs, dtype=tf.float32))
diff *= tf.cast(
tf.reshape(valid_mask, [batch_size, num_instances, 1, 1]),
tf.float32)
# Adding 0.001 in the denominator to avoid division by zero.
loss = tf.nn.l2_loss(diff) / (tf.reduce_sum(labels) + 0.001)
return loss
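The denominator normalizes by the total positive label mass rather than by pixel count. A toy arithmetic check of the formula (pure NumPy; recall tf.nn.l2_loss(x) == sum(x**2) / 2):
import numpy as np
labels = np.array([[[[1., 0.], [0., 1.]]]])   # [batch=1, num_instances=1, 2, 2]
probs = np.array([[[[0.5, 0.], [0., 1.]]]])
valid = np.array([[1.]])                      # [batch, num_instances]
diff = (labels - probs) * valid[..., None, None]
loss = (diff ** 2).sum() / 2 / (labels.sum() + 0.001)
print(round(loss, 4))                         # 0.0625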
class ShapemaskLoss(object):
"""ShapeMask mask loss function wrapper."""
def __init__(self):
self._binary_crossentropy = tf.keras.losses.BinaryCrossentropy(
reduction=tf.keras.losses.Reduction.SUM, from_logits=True)
def __call__(self, logits, labels, valid_mask):
"""ShapeMask mask cross entropy loss function wrapper.
Args:
logits: A float Tensor of shape [batch_size, num_instances, mask_size,
mask_size]. The logits are not necessarily between 0 and 1.
labels: A float16/float32 Tensor of shape [batch_size, num_instances,
mask_size, mask_size], where mask_size =
mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
for coarse masks and shape priors.
valid_mask: a binary mask of shape [batch_size, num_instances]
indicating valid training masks.
Returns:
loss: a float tensor representing total mask classification loss.
"""
with tf.name_scope('shapemask_loss'):
batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
labels = tf.cast(labels, tf.float32)
logits = tf.cast(logits, tf.float32)
loss = self._binary_crossentropy(labels, logits)
loss *= tf.cast(tf.reshape(
valid_mask, [batch_size, num_instances, 1, 1]), loss.dtype)
# Adding 0.001 in the denominator to avoid division by zero.
loss = tf.reduce_sum(loss) / (tf.reduce_sum(labels) + 0.001)
return loss
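A smoke test of the new wrapper on dummy tensors (TF2 eager; shapes follow the docstring, with mask_size = 32 and all masks marked valid):
import tensorflow as tf
loss_fn = ShapemaskLoss()
logits = tf.random.normal([2, 8, 32, 32])                              # [batch, num_instances, h, w]
labels = tf.cast(tf.random.uniform([2, 8, 32, 32]) > 0.5, tf.float32)
valid_mask = tf.ones([2, 8])
loss = loss_fn(logits, labels, valid_mask)  # scalar float32 tensor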