Commit 575c7458 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 336914823
parent eb4f8124
......@@ -145,6 +145,7 @@ class MaskHead(hyperparams.Config):
num_convs: int = 4
num_filters: int = 256
use_separable_conv: bool = False
class_agnostic: bool = False
@dataclasses.dataclass
......
......@@ -156,7 +156,8 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
activation=model_config.norm_activation.activation,
norm_momentum=model_config.norm_activation.norm_momentum,
norm_epsilon=model_config.norm_activation.norm_epsilon,
kernel_regularizer=l2_regularizer)
kernel_regularizer=l2_regularizer,
class_agnostic=model_config.mask_head.class_agnostic)
mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=(
......
......@@ -225,6 +225,7 @@ class MaskHead(tf.keras.layers.Layer):
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
class_agnostic=False,
**kwargs):
"""Initialize params to build the mask head.
......@@ -248,6 +249,8 @@ class MaskHead(tf.keras.layers.Layer):
kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
kernel.
bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
class_agnostic: `bool`, if set, we use a single channel mask head that
is shared between all classes.
**kwargs: other keyword arguments passed to Layer.
"""
super(MaskHead, self).__init__(**kwargs)
......@@ -263,6 +266,7 @@ class MaskHead(tf.keras.layers.Layer):
'norm_epsilon': norm_epsilon,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'class_agnostic': class_agnostic
}
if tf.keras.backend.image_data_format() == 'channels_last':
......@@ -330,8 +334,13 @@ class MaskHead(tf.keras.layers.Layer):
name='mask-upsampling')
self._deconv_bn = bn_op(name='mask-deconv-bn', **bn_kwargs)
if self._config_dict['class_agnostic']:
num_filters = 1
else:
num_filters = self._config_dict['num_classes']
conv_kwargs = {
'filters': self._config_dict['num_classes'],
'filters': num_filters,
'kernel_size': 1,
'padding': 'valid',
}
......@@ -395,17 +404,27 @@ class MaskHead(tf.keras.layers.Layer):
mask_height = height * self._config_dict['upsample_factor']
mask_width = width * self._config_dict['upsample_factor']
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
if self._config_dict['class_agnostic']:
logits = tf.reshape(logits, [-1, num_rois, mask_height, mask_width, 1])
else:
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
batch_indices = tf.tile(
tf.expand_dims(tf.range(batch_size), axis=1), [1, num_rois])
mask_indices = tf.tile(
tf.expand_dims(tf.range(num_rois), axis=0), [batch_size, 1])
if self._config_dict['class_agnostic']:
class_gather_indices = tf.zeros_like(roi_classes, dtype=tf.int32)
else:
class_gather_indices = tf.cast(roi_classes, dtype=tf.int32)
gather_indices = tf.stack(
[batch_indices, mask_indices, tf.cast(roi_classes, dtype=tf.int32)],
[batch_indices, mask_indices, class_gather_indices],
axis=2)
mask_outputs = tf.gather_nd(
tf.transpose(logits, [0, 1, 4, 2, 3]), gather_indices)
......
......@@ -120,6 +120,16 @@ class MaskHeadTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(
mask_head.get_config(), new_mask_head.get_config())
def test_forward_class_agnostic(self):
mask_head = instance_heads.MaskHead(
num_classes=3,
class_agnostic=True
)
roi_features = np.random.rand(2, 10, 14, 14, 16)
roi_classes = np.zeros((2, 10))
masks = mask_head([roi_features, roi_classes])
self.assertAllEqual(masks.numpy().shape, [2, 10, 28, 28])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment