Commit 457bcb85 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Refactor DeepMAC to process full batch and return mask logits with predict()

PiperOrigin-RevId: 426181961
parent c3f2134b
...@@ -1050,6 +1050,8 @@ class CenterNetCenterHeatmapTargetAssigner(object): ...@@ -1050,6 +1050,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
else: else:
raise ValueError(f'Unknown heatmap type - {self._box_heatmap_type}') raise ValueError(f'Unknown heatmap type - {self._box_heatmap_type}')
heatmap = tf.stop_gradient(heatmap)
heatmaps.append(heatmap) heatmaps.append(heatmap)
# Return the stacked heatmaps over the batch. # Return the stacked heatmaps over the batch.
......
...@@ -30,6 +30,7 @@ if tf_version.is_tf2(): ...@@ -30,6 +30,7 @@ if tf_version.is_tf2():
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING' INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING' PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation' DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency' DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency' DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
...@@ -50,7 +51,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [ ...@@ -50,7 +51,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'box_consistency_loss_weight', 'color_consistency_threshold', 'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight', 'color_consistency_dilation', 'color_consistency_loss_weight',
'box_consistency_loss_normalize', 'box_consistency_tightness', 'box_consistency_loss_normalize', 'box_consistency_tightness',
'color_consistency_warmup_steps', 'color_consistency_warmup_start' 'color_consistency_warmup_steps', 'color_consistency_warmup_start',
'use_only_last_stage'
]) ])
...@@ -140,33 +142,24 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None): ...@@ -140,33 +142,24 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Unknown network type {}'.format(name)) raise ValueError('Unknown network type {}'.format(name))
def crop_masks_within_boxes(masks, boxes, output_size): def _resize_instance_masks_non_empty(masks, shape):
"""Crops masks to lie tightly within the boxes. """Resize a non-empty tensor of masks to the given shape."""
height, width = shape
Args: flattened_masks, batch_size, num_instances = flatten_first2_dims(masks)
masks: A [num_instances, height, width] float tensor of masks. flattened_masks = flattened_masks[:, :, :, tf.newaxis]
boxes: A [num_instances, 4] sized tensor of normalized bounding boxes. flattened_masks = tf.image.resize(
output_size: The height and width of the output masks. flattened_masks, (height, width),
method=tf.image.ResizeMethod.BILINEAR)
Returns: return unpack_first2_dims(
masks: A [num_instances, output_size, output_size] tensor of masks which flattened_masks[:, :, :, 0], batch_size, num_instances)
are cropped to be tightly within the gives boxes and resized.
"""
masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[output_size, output_size])
return masks[:, 0, :, :, 0]
def resize_instance_masks(masks, shape): def resize_instance_masks(masks, shape):
height, width = shape batch_size, num_instances = tf.shape(masks)[0], tf.shape(masks)[1]
masks_ex = masks[:, :, :, tf.newaxis] return tf.cond(
masks_ex = tf.image.resize(masks_ex, (height, width), tf.shape(masks)[1] == 0,
method=tf.image.ResizeMethod.BILINEAR) lambda: tf.zeros((batch_size, num_instances, shape[0], shape[1])),
masks = masks_ex[:, :, :, 0] lambda: _resize_instance_masks_non_empty(masks, shape))
return masks
def filter_masked_classes(masked_class_ids, classes, weights, masks): def filter_masked_classes(masked_class_ids, classes, weights, masks):
...@@ -175,94 +168,132 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks): ...@@ -175,94 +168,132 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
Args: Args:
masked_class_ids: A list of class IDs allowed to have masks. These class IDs masked_class_ids: A list of class IDs allowed to have masks. These class IDs
are 1-indexed. are 1-indexed.
classes: A [num_instances, num_classes] float tensor containing the one-hot classes: A [batch_size, num_instances, num_classes] float tensor containing
encoded classes. the one-hot encoded classes.
weights: A [num_instances] float tensor containing the weights of each weights: A [batch_size, num_instances] float tensor containing the weights
sample. of each sample.
masks: A [num_instances, height, width] tensor containing the mask per masks: A [batch_size, num_instances, height, width] tensor containing the
instance. mask per instance.
Returns: Returns:
classes_filtered: A [num_instances, num_classes] float tensor containing the classes_filtered: A [batch_size, num_instances, num_classes] float tensor
one-hot encoded classes with classes not in masked_class_ids zeroed out. containing the one-hot encoded classes with classes not in
weights_filtered: A [num_instances] float tensor containing the weights of masked_class_ids zeroed out.
each sample with instances whose classes aren't in masked_class_ids weights_filtered: A [batch_size, num_instances] float tensor containing the
zeroed out. weights of each sample with instances whose classes aren't in
masks_filtered: A [num_instances, height, width] tensor containing the mask masked_class_ids zeroed out.
per instance with masks not belonging to masked_class_ids zeroed out. masks_filtered: A [batch_size, num_instances, height, width] tensor
containing the mask per instance with masks not belonging to
masked_class_ids zeroed out.
""" """
if len(masked_class_ids) == 0: # pylint:disable=g-explicit-length-test if len(masked_class_ids) == 0: # pylint:disable=g-explicit-length-test
return classes, weights, masks return classes, weights, masks
if tf.shape(classes)[0] == 0: if tf.shape(classes)[1] == 0:
return classes, weights, masks return classes, weights, masks
masked_class_ids = tf.constant(np.array(masked_class_ids, dtype=np.int32)) masked_class_ids = tf.constant(np.array(masked_class_ids, dtype=np.int32))
label_id_offset = 1 label_id_offset = 1
masked_class_ids -= label_id_offset masked_class_ids -= label_id_offset
class_ids = tf.argmax(classes, axis=1, output_type=tf.int32) class_ids = tf.argmax(classes, axis=2, output_type=tf.int32)
matched_classes = tf.equal( matched_classes = tf.equal(
class_ids[:, tf.newaxis], masked_class_ids[tf.newaxis, :] class_ids[:, :, tf.newaxis], masked_class_ids[tf.newaxis, tf.newaxis, :]
) )
matched_classes = tf.reduce_any(matched_classes, axis=1) matched_classes = tf.reduce_any(matched_classes, axis=2)
matched_classes = tf.cast(matched_classes, tf.float32) matched_classes = tf.cast(matched_classes, tf.float32)
return ( return (
classes * matched_classes[:, tf.newaxis], classes * matched_classes[:, :, tf.newaxis],
weights * matched_classes, weights * matched_classes,
masks * matched_classes[:, tf.newaxis, tf.newaxis] masks * matched_classes[:, :, tf.newaxis, tf.newaxis]
) )
def crop_and_resize_feature_map(features, boxes, size): def flatten_first2_dims(tensor):
"""Crop and resize regions from a single feature map given a set of boxes. """Flatten first 2 dimensions of a tensor.
Args: Args:
features: A [H, W, C] float tensor. tensor: A tensor with shape [M, N, ....]
boxes: A [N, 4] tensor of norrmalized boxes.
size: int, the size of the output features.
Returns: Returns:
per_box_features: A [N, size, size, C] tensor of cropped and resized flattened_tensor: A tensor of shape [M * N, ...]
features. M: int, the length of the first dimension of the input.
N: int, the length of the second dimension of the input.
""" """
return spatial_transform_ops.matmul_crop_and_resize( shape = tf.shape(tensor)
features[tf.newaxis], boxes[tf.newaxis], [size, size])[0] d1, d2, rest = shape[0], shape[1], shape[2:]
tensor = tf.reshape(
tensor, tf.concat([[d1 * d2], rest], axis=0))
return tensor, d1, d2
def unpack_first2_dims(tensor, dim1, dim2):
"""Unpack the flattened first dimension of the tensor into 2 dimensions.
Args:
tensor: A tensor of shape [dim1 * dim2, ...]
dim1: int, the size of the first dimension.
dim2: int, the size of the second dimension.
Returns:
unflattened_tensor: A tensor of shape [dim1, dim2, ...].
"""
shape = tf.shape(tensor)
result_shape = tf.concat([[dim1, dim2], shape[1:]], axis=0)
return tf.reshape(tensor, result_shape)
def crop_and_resize_instance_masks(masks, boxes, mask_size): def crop_and_resize_instance_masks(masks, boxes, mask_size):
"""Crop and resize each mask according to the given boxes. """Crop and resize each mask according to the given boxes.
Args: Args:
masks: A [N, H, W] float tensor. masks: A [B, N, H, W] float tensor.
boxes: A [N, 4] float tensor of normalized boxes. boxes: A [B, N, 4] float tensor of normalized boxes.
mask_size: int, the size of the output masks. mask_size: int, the size of the output masks.
Returns: Returns:
masks: A [N, mask_size, mask_size] float tensor of cropped and resized masks: A [B, N, mask_size, mask_size] float tensor of cropped and resized
instance masks. instance masks.
""" """
masks, batch_size, num_instances = flatten_first2_dims(masks)
boxes, _, _ = flatten_first2_dims(boxes)
cropped_masks = spatial_transform_ops.matmul_crop_and_resize( cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :], masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
[mask_size, mask_size]) [mask_size, mask_size])
cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4]) cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
return unpack_first2_dims(cropped_masks, batch_size, num_instances)
return cropped_masks
def fill_boxes(boxes, height, width): def fill_boxes(boxes, height, width):
"""Fills the area included in the box.""" """Fills the area included in the boxes with 1s.
blist = box_list.BoxList(boxes)
blist = box_list_ops.to_absolute_coordinates(blist, height, width) Args:
boxes = blist.get() boxes: A [batch_size, num_instances, 4] shapes float tensor of boxes given
in the normalized coordinate space.
height: int, height of the output image.
width: int, width of the output image.
Returns:
filled_boxes: A [batch_size, num_instances, height, width] shaped float
tensor with 1s in the area that falls inside each box.
"""
ymin, xmin, ymax, xmax = tf.unstack( ymin, xmin, ymax, xmax = tf.unstack(
boxes[:, tf.newaxis, tf.newaxis, :], 4, axis=3) boxes[:, :, tf.newaxis, tf.newaxis, :], 4, axis=4)
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
ymin *= height
ymax *= height
xmin *= width
xmax *= width
ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij') ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32) ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32)
ygrid, xgrid = ygrid[tf.newaxis, :, :], xgrid[tf.newaxis, :, :] ygrid, xgrid = (ygrid[tf.newaxis, tf.newaxis, :, :],
xgrid[tf.newaxis, tf.newaxis, :, :])
filled_boxes = tf.logical_and( filled_boxes = tf.logical_and(
tf.logical_and(ygrid >= ymin, ygrid <= ymax), tf.logical_and(ygrid >= ymin, ygrid <= ymax),
...@@ -289,7 +320,7 @@ def embedding_projection(x, y): ...@@ -289,7 +320,7 @@ def embedding_projection(x, y):
return dot return dot
def _get_2d_neighbors_kenel(): def _get_2d_neighbors_kernel():
"""Returns a conv. kernel that when applies generates 2D neighbors. """Returns a conv. kernel that when applies generates 2D neighbors.
Returns: Returns:
...@@ -311,20 +342,34 @@ def generate_2d_neighbors(input_tensor, dilation=2): ...@@ -311,20 +342,34 @@ def generate_2d_neighbors(input_tensor, dilation=2):
following ops on TPU won't have to pad the last dimension to 128. following ops on TPU won't have to pad the last dimension to 128.
Args: Args:
input_tensor: A float tensor of shape [height, width, channels]. input_tensor: A float tensor of shape [batch_size, height, width, channels].
dilation: int, the dilation factor for considering neighbors. dilation: int, the dilation factor for considering neighbors.
Returns: Returns:
output: A float tensor of all 8 2-D neighbors. of shape output: A float tensor of all 8 2-D neighbors. of shape
[8, height, width, channels]. [8, batch_size, height, width, channels].
""" """
input_tensor = tf.transpose(input_tensor, (2, 0, 1))
input_tensor = input_tensor[:, :, :, tf.newaxis]
kernel = _get_2d_neighbors_kenel() # TODO(vighneshb) Minimize tranposing here to save memory.
# input_tensor: [B, C, H, W]
input_tensor = tf.transpose(input_tensor, (0, 3, 1, 2))
# input_tensor: [B, C, H, W, 1]
input_tensor = input_tensor[:, :, :, :, tf.newaxis]
# input_tensor: [B * C, H, W, 1]
input_tensor, batch_size, channels = flatten_first2_dims(input_tensor)
kernel = _get_2d_neighbors_kernel()
# output: [B * C, H, W, 8]
output = tf.nn.atrous_conv2d(input_tensor, kernel, rate=dilation, output = tf.nn.atrous_conv2d(input_tensor, kernel, rate=dilation,
padding='SAME') padding='SAME')
return tf.transpose(output, [3, 1, 2, 0]) # output: [B, C, H, W, 8]
output = unpack_first2_dims(output, batch_size, channels)
# return: [8, B, H, W, C]
return tf.transpose(output, [4, 0, 2, 3, 1])
def gaussian_pixel_similarity(a, b, theta): def gaussian_pixel_similarity(a, b, theta):
...@@ -339,12 +384,12 @@ def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0): ...@@ -339,12 +384,12 @@ def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
[1]: https://arxiv.org/abs/2012.02310 [1]: https://arxiv.org/abs/2012.02310
Args: Args:
feature_map: A float tensor of shape [height, width, channels] feature_map: A float tensor of shape [batch_size, height, width, channels]
dilation: int, the dilation factor. dilation: int, the dilation factor.
theta: The denominator while taking difference inside the gaussian. theta: The denominator while taking difference inside the gaussian.
Returns: Returns:
dilated_similarity: A tensor of shape [8, height, width] dilated_similarity: A tensor of shape [8, batch_size, height, width]
""" """
neighbors = generate_2d_neighbors(feature_map, dilation) neighbors = generate_2d_neighbors(feature_map, dilation)
feature_map = feature_map[tf.newaxis] feature_map = feature_map[tf.newaxis]
...@@ -358,21 +403,26 @@ def dilated_cross_same_mask_label(instance_masks, dilation=2): ...@@ -358,21 +403,26 @@ def dilated_cross_same_mask_label(instance_masks, dilation=2):
[1]: https://arxiv.org/abs/2012.02310 [1]: https://arxiv.org/abs/2012.02310
Args: Args:
instance_masks: A float tensor of shape [num_instances, height, width] instance_masks: A float tensor of shape [batch_size, num_instances,
height, width]
dilation: int, the dilation factor. dilation: int, the dilation factor.
Returns: Returns:
dilated_same_label: A tensor of shape [8, num_instances, height, width] dilated_same_label: A tensor of shape [8, batch_size, num_instances,
height, width]
""" """
instance_masks = tf.transpose(instance_masks, (1, 2, 0)) # instance_masks: [batch_size, height, width, num_instances]
instance_masks = tf.transpose(instance_masks, (0, 2, 3, 1))
# neighbors: [8, batch_size, height, width, num_instances]
neighbors = generate_2d_neighbors(instance_masks, dilation) neighbors = generate_2d_neighbors(instance_masks, dilation)
# instance_masks = [1, batch_size, height, width, num_instances]
instance_masks = instance_masks[tf.newaxis] instance_masks = instance_masks[tf.newaxis]
same_mask_prob = ((instance_masks * neighbors) + same_mask_prob = ((instance_masks * neighbors) +
((1 - instance_masks) * (1 - neighbors))) ((1 - instance_masks) * (1 - neighbors)))
return tf.transpose(same_mask_prob, (0, 3, 1, 2)) return tf.transpose(same_mask_prob, (0, 1, 4, 2, 3))
def _per_pixel_single_conv(input_tensor, params, channels): def _per_pixel_single_conv(input_tensor, params, channels):
...@@ -722,6 +772,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer): ...@@ -722,6 +772,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
return tf.squeeze(out, axis=-1) return tf.squeeze(out, axis=-1)
def _batch_gt_list(gt_list):
return tf.stack(gt_list, axis=0)
def deepmac_proto_to_params(deepmac_config): def deepmac_proto_to_params(deepmac_config):
"""Convert proto to named tuple.""" """Convert proto to named tuple."""
...@@ -765,7 +819,8 @@ def deepmac_proto_to_params(deepmac_config): ...@@ -765,7 +819,8 @@ def deepmac_proto_to_params(deepmac_config):
color_consistency_warmup_steps= color_consistency_warmup_steps=
deepmac_config.color_consistency_warmup_steps, deepmac_config.color_consistency_warmup_steps,
color_consistency_warmup_start= color_consistency_warmup_start=
deepmac_config.color_consistency_warmup_start deepmac_config.color_consistency_warmup_start,
use_only_last_stage=deepmac_config.use_only_last_stage
) )
...@@ -808,8 +863,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -808,8 +863,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
f'pixel_embedding_dim({pixel_embedding_dim}) ' f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).') f'must be same as dim({dim}).')
loss = self._deepmac_params.classification_loss
super(DeepMACMetaArch, self).__init__( super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries, is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor, num_classes=num_classes, feature_extractor=feature_extractor,
...@@ -847,60 +900,62 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -847,60 +900,62 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Get the input to the mask network, given bounding boxes. """Get the input to the mask network, given bounding boxes.
Args: Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in boxes: A [batch_size, num_instances, 4] float tensor containing bounding
normalized coordinates. boxes in normalized coordinates.
pixel_embedding: A [height, width, embedding_size] float tensor pixel_embedding: A [batch_size, height, width, embedding_size] float
containing spatial pixel embeddings. tensor containing spatial pixel embeddings.
Returns: Returns:
embedding: A [num_instances, mask_height, mask_width, embedding_size + 2] embedding: A [batch_size, num_instances, mask_height, mask_width,
float tensor containing the inputs to the mask network. For each embedding_size + 2] float tensor containing the inputs to the mask
bounding box, we concatenate the normalized box coordinates to the network. For each bounding box, we concatenate the normalized box
cropped pixel embeddings. If predict_full_resolution_masks is set, coordinates to the cropped pixel embeddings. If
mask_height and mask_width are the same as height and width of predict_full_resolution_masks is set, mask_height and mask_width are
pixel_embedding. If not, mask_height and mask_width are the same as the same as height and width of pixel_embedding. If not, mask_height
mask_size. and mask_width are the same as mask_size.
""" """
num_instances = tf.shape(boxes)[0] batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_size = self._deepmac_params.mask_size mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks: if self._deepmac_params.predict_full_resolution_masks:
num_instances = tf.shape(boxes)[0] num_instances = tf.shape(boxes)[1]
pixel_embedding = pixel_embedding[tf.newaxis, :, :, :] pixel_embedding = pixel_embedding[:, tf.newaxis, :, :, :]
pixel_embeddings_processed = tf.tile(pixel_embedding, pixel_embeddings_processed = tf.tile(pixel_embedding,
[num_instances, 1, 1, 1]) [1, num_instances, 1, 1, 1])
image_shape = tf.shape(pixel_embeddings_processed) image_shape = tf.shape(pixel_embeddings_processed)
image_height, image_width = image_shape[1], image_shape[2] image_height, image_width = image_shape[2], image_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height), y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height),
tf.linspace(0.0, 1.0, image_width), tf.linspace(0.0, 1.0, image_width),
indexing='ij') indexing='ij')
blist = box_list.BoxList(boxes) ycenter = (boxes[:, :, 0] + boxes[:, :, 2]) / 2.0
ycenter, xcenter, _, _ = blist.get_center_coordinates_and_sizes() xcenter = (boxes[:, :, 1] + boxes[:, :, 3]) / 2.0
y_grid = y_grid[tf.newaxis, :, :] y_grid = y_grid[tf.newaxis, tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, :, :] x_grid = x_grid[tf.newaxis, tf.newaxis, :, :]
y_grid -= ycenter[:, tf.newaxis, tf.newaxis] y_grid -= ycenter[:, :, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, tf.newaxis, tf.newaxis] x_grid -= xcenter[:, :, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=3) coords = tf.stack([y_grid, x_grid], axis=4)
else: else:
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False. # TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
pixel_embeddings_processed = crop_and_resize_feature_map( embeddings = spatial_transform_ops.matmul_crop_and_resize(
pixel_embedding, boxes, mask_size) pixel_embedding, boxes, [mask_size, mask_size])
pixel_embeddings_processed = embeddings
mask_shape = tf.shape(pixel_embeddings_processed) mask_shape = tf.shape(pixel_embeddings_processed)
mask_height, mask_width = mask_shape[1], mask_shape[2] mask_height, mask_width = mask_shape[2], mask_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height), y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height),
tf.linspace(-1.0, 1.0, mask_width), tf.linspace(-1.0, 1.0, mask_width),
indexing='ij') indexing='ij')
coords = tf.stack([y_grid, x_grid], axis=2) coords = tf.stack([y_grid, x_grid], axis=2)
coords = coords[tf.newaxis, :, :, :] coords = coords[tf.newaxis, tf.newaxis, :, :, :]
coords = tf.tile(coords, [num_instances, 1, 1, 1]) coords = tf.tile(coords, [batch_size, num_instances, 1, 1, 1])
if self._deepmac_params.use_xy: if self._deepmac_params.use_xy:
return tf.concat([coords, pixel_embeddings_processed], axis=3) return tf.concat([coords, pixel_embeddings_processed], axis=4)
else: else:
return pixel_embeddings_processed return pixel_embeddings_processed
...@@ -908,43 +963,94 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -908,43 +963,94 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Return the instance embeddings from bounding box centers. """Return the instance embeddings from bounding box centers.
Args: Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The boxes: A [batch_size, num_instances, 4] float tensor holding bounding
coordinates are in normalized input space. boxes. The coordinates are in normalized input space.
instance_embedding: A [height, width, embedding_size] float tensor instance_embedding: A [batch_size, height, width, embedding_size] float
containing the instance embeddings. tensor containing the instance embeddings.
Returns: Returns:
instance_embeddings: A [num_instances, embedding_size] shaped float tensor instance_embeddings: A [batch_size, num_instances, embedding_size]
containing the center embedding for each instance. shaped float tensor containing the center embedding for each instance.
""" """
blist = box_list.BoxList(boxes)
output_height = tf.shape(instance_embedding)[0] output_height = tf.cast(tf.shape(instance_embedding)[1], tf.float32)
output_width = tf.shape(instance_embedding)[1] output_width = tf.cast(tf.shape(instance_embedding)[2], tf.float32)
ymin = boxes[:, :, 0]
blist_output = box_list_ops.to_absolute_coordinates( xmin = boxes[:, :, 1]
blist, output_height, output_width, check_range=False) ymax = boxes[:, :, 2]
(y_center_output, x_center_output, xmax = boxes[:, :, 3]
_, _) = blist_output.get_center_coordinates_and_sizes()
center_coords_output = tf.stack([y_center_output, x_center_output], axis=1) y_center_output = (ymin + ymax) * output_height / 2.0
x_center_output = (xmin + xmax) * output_width / 2.0
center_coords_output = tf.stack([y_center_output, x_center_output], axis=2)
center_coords_output_int = tf.cast(center_coords_output, tf.int32) center_coords_output_int = tf.cast(center_coords_output, tf.int32)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int,
batch_dims=1)
return center_latents return center_latents
def predict(self, preprocessed_inputs, other_inputs):
prediction_dict = super(DeepMACMetaArch, self).predict(
preprocessed_inputs, other_inputs)
mask_logits = self._predict_mask_logits_from_gt_boxes(prediction_dict)
prediction_dict[MASK_LOGITS_GT_BOXES] = mask_logits
return prediction_dict
def _predict_mask_logits_from_embeddings(
self, pixel_embedding, instance_embedding, boxes):
mask_input = self._get_mask_head_input(boxes, pixel_embedding)
mask_input, batch_size, num_instances = flatten_first2_dims(mask_input)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
instance_embeddings, _, _ = flatten_first2_dims(instance_embeddings)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_logits = unpack_first2_dims(
mask_logits, batch_size, num_instances)
return mask_logits
def _predict_mask_logits_from_gt_boxes(self, prediction_dict):
mask_logits_list = []
boxes = _batch_gt_list(self.groundtruth_lists(fields.BoxListFields.boxes))
instance_embedding_list = prediction_dict[INSTANCE_EMBEDDING]
pixel_embedding_list = prediction_dict[PIXEL_EMBEDDING]
if self._deepmac_params.use_only_last_stage:
instance_embedding_list = [instance_embedding_list[-1]]
pixel_embedding_list = [pixel_embedding_list[-1]]
for (instance_embedding, pixel_embedding) in zip(instance_embedding_list,
pixel_embedding_list):
mask_logits_list.append(
self._predict_mask_logits_from_embeddings(
pixel_embedding, instance_embedding, boxes))
return mask_logits_list
def _get_groundtruth_mask_output(self, boxes, masks): def _get_groundtruth_mask_output(self, boxes, masks):
"""Get the expected mask output for each box. """Get the expected mask output for each box.
Args: Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in boxes: A [batch_size, num_instances, 4] float tensor containing bounding
normalized coordinates. boxes in normalized coordinates.
masks: A [num_instances, height, width] float tensor containing binary masks: A [batch_size, num_instances, height, width] float tensor
ground truth masks. containing binary ground truth masks.
Returns: Returns:
masks: If predict_full_resolution_masks is set, masks are not resized masks: If predict_full_resolution_masks is set, masks are not resized
and the size of this tensor is [num_instances, input_height, input_width]. and the size of this tensor is [batch_size, num_instances,
Otherwise, returns a tensor of size [num_instances, mask_size, mask_size]. input_height, input_width]. Otherwise, returns a tensor of size
[batch_size, num_instances, mask_size, mask_size].
""" """
mask_size = self._deepmac_params.mask_size mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks: if self._deepmac_params.predict_full_resolution_masks:
return masks return masks
...@@ -957,9 +1063,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -957,9 +1063,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return cropped_masks return cropped_masks
def _resize_logits_like_gt(self, logits, gt): def _resize_logits_like_gt(self, logits, gt):
height, width = tf.shape(gt)[2], tf.shape(gt)[3]
height, width = tf.shape(gt)[1], tf.shape(gt)[2]
return resize_instance_masks(logits, (height, width)) return resize_instance_masks(logits, (height, width))
def _aggregate_classification_loss(self, loss, gt, pred, method): def _aggregate_classification_loss(self, loss, gt, pred, method):
...@@ -1016,54 +1120,59 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1016,54 +1120,59 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else: else:
raise ValueError('Unknown loss aggregation - {}'.format(method)) raise ValueError('Unknown loss aggregation - {}'.format(method))
def _compute_per_instance_mask_prediction_loss( def _compute_mask_prediction_loss(
self, boxes, mask_logits, mask_gt): self, boxes, mask_logits, mask_gt):
"""Compute the per-instance mask loss. """Compute the per-instance mask loss.
Args: Args:
boxes: A [num_instances, 4] float tensor of GT boxes. boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
mask_logits: A [num_instances, height, width] float tensor of predicted mask_logits: A [batch_suze, num_instances, height, width] float tensor of
masks predicted masks
mask_gt: The groundtruth mask. mask_gt: The groundtruth mask of same shape as mask_logits.
Returns: Returns:
loss: A [num_instances] shaped tensor with the loss for each instance. loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance.
""" """
num_instances = tf.shape(boxes)[0] batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt) mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
mask_logits = tf.reshape(mask_logits, [num_instances, -1, 1]) mask_logits = tf.reshape(mask_logits, [batch_size * num_instances, -1, 1])
mask_gt = tf.reshape(mask_gt, [num_instances, -1, 1]) mask_gt = tf.reshape(mask_gt, [batch_size * num_instances, -1, 1])
loss = self._deepmac_params.classification_loss( loss = self._deepmac_params.classification_loss(
prediction_tensor=mask_logits, prediction_tensor=mask_logits,
target_tensor=mask_gt, target_tensor=mask_gt,
weights=tf.ones_like(mask_logits)) weights=tf.ones_like(mask_logits))
return self._aggregate_classification_loss( loss = self._aggregate_classification_loss(
loss, mask_gt, mask_logits, 'normalize_auto') loss, mask_gt, mask_logits, 'normalize_auto')
return tf.reshape(loss, [batch_size, num_instances])
def _compute_per_instance_box_consistency_loss( def _compute_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits): self, boxes_gt, boxes_for_crop, mask_logits):
"""Compute the per-instance box consistency loss. """Compute the per-instance box consistency loss.
Args: Args:
boxes_gt: A [num_instances, 4] float tensor of GT boxes. boxes_gt: A [batch_size, num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [num_instances, 4] float tensor of augmented boxes, boxes_for_crop: A [batch_size, num_instances, 4] float tensor of
to be used when using crop-and-resize based mask head. augmented boxes, to be used when using crop-and-resize based mask head.
mask_logits: A [num_instances, height, width] float tensor of predicted mask_logits: A [batch_size, num_instances, height, width]
masks. float tensor of predicted masks.
Returns: Returns:
loss: A [num_instances] shaped tensor with the loss for each instance. loss: A [batch_size, num_instances] shaped tensor with the loss for
each instance in the batch.
""" """
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2] shape = tf.shape(mask_logits)
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis] batch_size, num_instances, height, width = (
mask_logits = mask_logits[:, :, :, tf.newaxis] shape[0], shape[1], shape[2], shape[3])
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, :, tf.newaxis]
if self._deepmac_params.predict_full_resolution_masks: if self._deepmac_params.predict_full_resolution_masks:
gt_crop = filled_boxes[:, :, :, 0] gt_crop = filled_boxes[:, :, :, :, 0]
pred_crop = mask_logits[:, :, :, 0] pred_crop = mask_logits[:, :, :, :, 0]
else: else:
gt_crop = crop_and_resize_instance_masks( gt_crop = crop_and_resize_instance_masks(
filled_boxes, boxes_for_crop, self._deepmac_params.mask_size) filled_boxes, boxes_for_crop, self._deepmac_params.mask_size)
...@@ -1071,7 +1180,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1071,7 +1180,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits, boxes_for_crop, self._deepmac_params.mask_size) mask_logits, boxes_for_crop, self._deepmac_params.mask_size)
loss = 0.0 loss = 0.0
for axis in [1, 2]: for axis in [2, 3]:
if self._deepmac_params.box_consistency_tightness: if self._deepmac_params.box_consistency_tightness:
pred_max_raw = tf.reduce_max(pred_crop, axis=axis) pred_max_raw = tf.reduce_max(pred_crop, axis=axis)
...@@ -1083,44 +1192,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1083,44 +1192,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else: else:
pred_max = tf.reduce_max(pred_crop, axis=axis) pred_max = tf.reduce_max(pred_crop, axis=axis)
pred_max = pred_max[:, :, tf.newaxis] pred_max = pred_max[:, :, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis] gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, :, tf.newaxis]
flat_pred, batch_size, num_instances = flatten_first2_dims(pred_max)
flat_gt, _, _ = flatten_first2_dims(gt_max)
# We use flat tensors while calling loss functions because we
# want the loss per-instance to later multiply with the per-instance
# weight. Flattening the first 2 dims allows us to represent each instance
# in each batch as though they were samples in a larger batch.
raw_loss = self._deepmac_params.classification_loss( raw_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max, prediction_tensor=flat_pred,
target_tensor=gt_max, target_tensor=flat_gt,
weights=tf.ones_like(pred_max)) weights=tf.ones_like(flat_pred))
loss += self._aggregate_classification_loss( agg_loss = self._aggregate_classification_loss(
raw_loss, gt_max, pred_max, raw_loss, flat_gt, flat_pred,
self._deepmac_params.box_consistency_loss_normalize) self._deepmac_params.box_consistency_loss_normalize)
loss += unpack_first2_dims(agg_loss, batch_size, num_instances)
return loss return loss
def _compute_per_instance_color_consistency_loss( def _compute_color_consistency_loss(
self, boxes, preprocessed_image, mask_logits): self, boxes, preprocessed_image, mask_logits):
"""Compute the per-instance color consistency loss. """Compute the per-instance color consistency loss.
Args: Args:
boxes: A [num_instances, 4] float tensor of GT boxes. boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
preprocessed_image: A [height, width, 3] float tensor containing the preprocessed_image: A [batch_size, height, width, 3]
preprocessed image. float tensor containing the preprocessed image.
mask_logits: A [num_instances, height, width] float tensor of predicted mask_logits: A [batch_size, num_instances, height, width] float tensor of
masks. predicted masks.
Returns: Returns:
loss: A [num_instances] shaped tensor with the loss for each instance. loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance fpr each sample in the batch.
""" """
if not self._deepmac_params.predict_full_resolution_masks: if not self._deepmac_params.predict_full_resolution_masks:
logging.info('Color consistency is not implemented with RoIAlign ' logging.info('Color consistency is not implemented with RoIAlign '
', i.e, fixed sized masks. Returning 0 loss.') ', i.e, fixed sized masks. Returning 0 loss.')
return tf.zeros(tf.shape(boxes)[0]) return tf.zeros(tf.shape(boxes)[:2])
dilation = self._deepmac_params.color_consistency_dilation dilation = self._deepmac_params.color_consistency_dilation
height, width = (tf.shape(preprocessed_image)[0], height, width = (tf.shape(preprocessed_image)[1],
tf.shape(preprocessed_image)[1]) tf.shape(preprocessed_image)[2])
color_similarity = dilated_cross_pixel_similarity( color_similarity = dilated_cross_pixel_similarity(
preprocessed_image, dilation=dilation, theta=2.0) preprocessed_image, dilation=dilation, theta=2.0)
mask_probs = tf.nn.sigmoid(mask_logits) mask_probs = tf.nn.sigmoid(mask_logits)
...@@ -1132,20 +1250,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1132,20 +1250,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
color_similarity_mask = ( color_similarity_mask = (
color_similarity > self._deepmac_params.color_consistency_threshold) color_similarity > self._deepmac_params.color_consistency_threshold)
color_similarity_mask = tf.cast( color_similarity_mask = tf.cast(
color_similarity_mask[:, tf.newaxis, :, :], tf.float32) color_similarity_mask[:, :, tf.newaxis, :, :], tf.float32)
per_pixel_loss = -(color_similarity_mask * per_pixel_loss = -(color_similarity_mask *
tf.math.log(same_mask_label_probability)) tf.math.log(same_mask_label_probability))
# TODO(vighneshb) explore if shrinking the box by 1px helps. # TODO(vighneshb) explore if shrinking the box by 1px helps.
box_mask = fill_boxes(boxes, height, width) box_mask = fill_boxes(boxes, height, width)
box_mask_expanded = box_mask[tf.newaxis, :, :, :] box_mask_expanded = box_mask[tf.newaxis]
per_pixel_loss = per_pixel_loss * box_mask_expanded per_pixel_loss = per_pixel_loss * box_mask_expanded
loss = tf.reduce_sum(per_pixel_loss, axis=[0, 2, 3]) loss = tf.reduce_sum(per_pixel_loss, axis=[0, 3, 4])
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2])) num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[2, 3]))
loss = loss / num_box_pixels loss = loss / num_box_pixels
if ((self._deepmac_params.color_consistency_warmup_steps > 0) and if ((self._deepmac_params.color_consistency_warmup_steps > 0) and
self._is_training): tf.keras.backend.learning_phase()):
training_step = tf.cast(self.training_step, tf.float32) training_step = tf.cast(self.training_step, tf.float32)
warmup_steps = tf.cast( warmup_steps = tf.cast(
self._deepmac_params.color_consistency_warmup_steps, tf.float32) self._deepmac_params.color_consistency_warmup_steps, tf.float32)
...@@ -1157,56 +1275,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1157,56 +1275,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return loss return loss
def _compute_per_instance_deepmac_losses( def _compute_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding, self, boxes, masks_logits, masks_gt, image):
image):
"""Returns the mask loss per instance. """Returns the mask loss per instance.
Args: Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The boxes: A [batch_size, num_instances, 4] float tensor holding bounding
coordinates are in normalized input space. boxes. The coordinates are in normalized input space.
masks: A [num_instances, input_height, input_width] float tensor masks_logits: A [batch_size, num_instances, input_height, input_width]
containing the instance masks. float tensor containing the instance mask predictions in their logit
instance_embedding: A [output_height, output_width, embedding_size] form.
float tensor containing the instance embeddings. masks_gt: A [batch_size, num_instances, input_height, input_width] float
pixel_embedding: optional [output_height, output_width, tensor containing the groundtruth masks.
pixel_embedding_size] float tensor containing the per-pixel embeddings. image: [batch_size, output_height, output_width, channels] float tensor
image: [output_height, output_width, channels] float tensor
denoting the input image. denoting the input image.
Returns: Returns:
mask_prediction_loss: A [num_instances] shaped float tensor containing the mask_prediction_loss: A [batch_size, num_instances] shaped float tensor
mask loss for each instance. containing the mask loss for each instance in the batch.
box_consistency_loss: A [num_instances] shaped float tensor containing box_consistency_loss: A [batch_size, num_instances] shaped float tensor
the box consistency loss for each instance. containing the box consistency loss for each instance in the batch.
box_consistency_loss: A [num_instances] shaped float tensor containing box_consistency_loss: A [batch_size, num_instances] shaped float tensor
the color consistency loss. containing the color consistency loss in the batch.
""" """
if tf.keras.backend.learning_phase(): if tf.keras.backend.learning_phase():
boxes_for_crop = preprocessor.random_jitter_boxes( boxes = tf.stop_gradient(boxes)
boxes, self._deepmac_params.max_roi_jitter_ratio, def jitter_func(boxes):
jitter_mode=self._deepmac_params.roi_jitter_mode) return preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
boxes_for_crop = tf.map_fn(jitter_func,
boxes, parallel_iterations=128)
else: else:
boxes_for_crop = boxes boxes_for_crop = boxes
mask_input = self._get_mask_head_input( mask_gt = self._get_groundtruth_mask_output(
boxes_for_crop, pixel_embedding) boxes_for_crop, masks_gt)
instance_embeddings = self._get_instance_embeddings(
boxes_for_crop, instance_embedding)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_gt = self._get_groundtruth_mask_output(boxes_for_crop, masks)
mask_prediction_loss = self._compute_per_instance_mask_prediction_loss( mask_prediction_loss = self._compute_mask_prediction_loss(
boxes_for_crop, mask_logits, mask_gt) boxes_for_crop, masks_logits, mask_gt)
box_consistency_loss = self._compute_per_instance_box_consistency_loss( box_consistency_loss = self._compute_box_consistency_loss(
boxes, boxes_for_crop, mask_logits) boxes, boxes_for_crop, masks_logits)
color_consistency_loss = self._compute_per_instance_color_consistency_loss( color_consistency_loss = self._compute_color_consistency_loss(
boxes, image, mask_logits) boxes, image, masks_logits)
return { return {
DEEP_MASK_ESTIMATION: mask_prediction_loss, DEEP_MASK_ESTIMATION: mask_prediction_loss,
...@@ -1224,7 +1339,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1224,7 +1339,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
' consistency loss is not supported in TF1.')) ' consistency loss is not supported in TF1.'))
return tfio.experimental.color.rgb_to_lab(raw_image) return tfio.experimental.color.rgb_to_lab(raw_image)
def _compute_instance_masks_loss(self, prediction_dict): def _compute_masks_loss(self, prediction_dict):
"""Computes the mask loss. """Computes the mask loss.
Args: Args:
...@@ -1236,10 +1351,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1236,10 +1351,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
Returns: Returns:
loss_dict: A dict mapping string (loss names) to scalar floats. loss_dict: A dict mapping string (loss names) to scalar floats.
""" """
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
allowed_masked_classes_ids = ( allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids) self._deepmac_params.allowed_masked_classes_ids)
...@@ -1248,8 +1359,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1248,8 +1359,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for loss_name in MASK_LOSSES: for loss_name in MASK_LOSSES:
loss_dict[loss_name] = 0.0 loss_dict[loss_name] = 0.0
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0]) prediction_shape = tf.shape(prediction_dict[MASK_LOGITS_GT_BOXES][0])
height, width = prediction_shape[1], prediction_shape[2] height, width = prediction_shape[2], prediction_shape[3]
preprocessed_image = tf.image.resize( preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width)) prediction_dict['preprocessed_inputs'], (height, width))
...@@ -1258,42 +1369,46 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1258,42 +1369,46 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
# TODO(vighneshb) See if we can save memory by only using the final # TODO(vighneshb) See if we can save memory by only using the final
# prediction # prediction
# Iterate over multiple preidctions by backbone (for hourglass length=2) # Iterate over multiple preidctions by backbone (for hourglass length=2)
for instance_pred, pixel_pred in zip(
prediction_dict[INSTANCE_EMBEDDING], gt_boxes = _batch_gt_list(
prediction_dict[PIXEL_EMBEDDING]): self.groundtruth_lists(fields.BoxListFields.boxes))
# Iterate over samples in batch gt_weights = _batch_gt_list(
# TODO(vighneshb) find out how autograph is handling this. Converting self.groundtruth_lists(fields.BoxListFields.weights))
# to a single op may give speed improvements gt_masks = _batch_gt_list(
for i, (boxes, weights, classes, masks) in enumerate( self.groundtruth_lists(fields.BoxListFields.masks))
zip(gt_boxes_list, gt_weights_list, gt_classes_list, gt_masks_list)): gt_classes = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.classes))
# TODO(vighneshb) Add sub-sampling back if required.
classes, valid_mask_weights, masks = filter_masked_classes( mask_logits_list = prediction_dict[MASK_LOGITS_GT_BOXES]
allowed_masked_classes_ids, classes, weights, masks) for mask_logits in mask_logits_list:
sample_loss_dict = self._compute_per_instance_deepmac_losses( # TODO(vighneshb) Add sub-sampling back if required.
boxes, masks, instance_pred[i], pixel_pred[i], image[i]) _, valid_mask_weights, gt_masks = filter_masked_classes(
allowed_masked_classes_ids, gt_classes,
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights gt_weights, gt_masks)
for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= weights sample_loss_dict = self._compute_deepmac_losses(
gt_boxes, mask_logits, gt_masks, image)
num_instances = tf.maximum(tf.reduce_sum(weights), 1.0)
num_instances_allowed = tf.maximum( sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
tf.reduce_sum(valid_mask_weights), 1.0) for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= gt_weights
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) / num_instances = tf.maximum(tf.reduce_sum(gt_weights), 1.0)
num_instances_allowed) num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) / loss_dict[DEEP_MASK_ESTIMATION] += (
num_instances) tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) /
num_instances_allowed)
batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING]) for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) /
return dict((key, loss / float(batch_size * num_predictions)) num_instances)
num_predictions = len(mask_logits_list)
return dict((key, loss / float(num_predictions))
for key, loss in loss_dict.items()) for key, loss in loss_dict.items())
def loss(self, prediction_dict, true_image_shapes, scope=None): def loss(self, prediction_dict, true_image_shapes, scope=None):
...@@ -1302,7 +1417,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1302,7 +1417,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict, true_image_shapes, scope) prediction_dict, true_image_shapes, scope)
if self._deepmac_params is not None: if self._deepmac_params is not None:
mask_loss_dict = self._compute_instance_masks_loss( mask_loss_dict = self._compute_masks_loss(
prediction_dict=prediction_dict) prediction_dict=prediction_dict)
for loss_name in MASK_LOSSES: for loss_name in MASK_LOSSES:
...@@ -1363,50 +1478,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1363,50 +1478,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_size] containing binary per-box instance masks. mask_size] containing binary per-box instance masks.
""" """
def process(elems): height, width = (tf.shape(instance_embedding)[1],
boxes, instance_embedding, pixel_embedding = elems tf.shape(instance_embedding)[2])
return self._postprocess_sample(boxes, instance_embedding,
pixel_embedding)
max_instances = self._center_params.max_box_predictions
return tf.map_fn(process, [boxes_output_stride, instance_embedding,
pixel_embedding],
dtype=tf.float32, parallel_iterations=max_instances)
def _postprocess_sample(self, boxes_output_stride,
instance_embedding, pixel_embedding):
"""Post process masks for a single sample.
Args:
boxes_output_stride: A [num_instances, 4] float tensor containing
bounding boxes in the absolute output space.
instance_embedding: A [output_height, output_width, embedding_size]
float tensor containing instance embeddings.
pixel_embedding: A [batch_size, output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embedding.
Returns:
masks: A float tensor of size [num_instances, mask_height, mask_width]
containing binary per-box instance masks. If
predict_full_resolution_masks is set, the masks will be resized to
postprocess_crop_size. Otherwise, mask_height=mask_width=mask_size
"""
height, width = (tf.shape(instance_embedding)[0],
tf.shape(instance_embedding)[1])
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
blist = box_list.BoxList(boxes_output_stride) ymin, xmin, ymax, xmax = tf.unstack(boxes_output_stride, axis=2)
blist = box_list_ops.to_normalized_coordinates( ymin /= height
blist, height, width, check_range=False) ymax /= height
boxes = blist.get() xmin /= width
xmax /= width
mask_input = self._get_mask_head_input(boxes, pixel_embedding) boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
mask_logits = self._mask_net( mask_logits = self._predict_mask_logits_from_embeddings(
instance_embeddings, mask_input, pixel_embedding, instance_embedding, boxes)
training=tf.keras.backend.learning_phase())
# TODO(vighneshb) Explore sweeping mask thresholds. # TODO(vighneshb) Explore sweeping mask thresholds.
...@@ -1416,7 +1499,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -1416,7 +1499,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
height *= self._stride height *= self._stride
width *= self._stride width *= self._stride
mask_logits = resize_instance_masks(mask_logits, (height, width)) mask_logits = resize_instance_masks(mask_logits, (height, width))
mask_logits = crop_masks_within_boxes(
mask_logits = crop_and_resize_instance_masks(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size) mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
masks_prob = tf.nn.sigmoid(mask_logits) masks_prob = tf.nn.sigmoid(mask_logits)
......
...@@ -44,6 +44,7 @@ DEEPMAC_PROTO_TEXT = """ ...@@ -44,6 +44,7 @@ DEEPMAC_PROTO_TEXT = """
box_consistency_loss_normalize: NORMALIZE_AUTO box_consistency_loss_normalize: NORMALIZE_AUTO
color_consistency_warmup_steps: 20 color_consistency_warmup_steps: 20
color_consistency_warmup_start: 10 color_consistency_warmup_start: 10
use_only_last_stage: false
""" """
...@@ -117,10 +118,11 @@ def build_meta_arch(**override_params): ...@@ -117,10 +118,11 @@ def build_meta_arch(**override_params):
mask_size=16, mask_size=16,
postprocess_crop_size=128, postprocess_crop_size=128,
max_roi_jitter_ratio=0.0, max_roi_jitter_ratio=0.0,
roi_jitter_mode='random', roi_jitter_mode='default',
color_consistency_dilation=2, color_consistency_dilation=2,
color_consistency_warmup_steps=0, color_consistency_warmup_steps=0,
color_consistency_warmup_start=0) color_consistency_warmup_start=0,
use_only_last_stage=True)
params.update(override_params) params.update(override_params)
...@@ -185,6 +187,7 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -185,6 +187,7 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(params, deepmac_meta_arch.DeepMACParams) self.assertIsInstance(params, deepmac_meta_arch.DeepMACParams)
self.assertEqual(params.dim, 153) self.assertEqual(params.dim, 153)
self.assertEqual(params.box_consistency_loss_normalize, 'normalize_auto') self.assertEqual(params.box_consistency_loss_normalize, 'normalize_auto')
self.assertFalse(params.use_only_last_stage)
def test_subsample_trivial(self): def test_subsample_trivial(self):
"""Test subsampling masks.""" """Test subsampling masks."""
...@@ -201,32 +204,71 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -201,32 +204,71 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(result[2], boxes) self.assertAllClose(result[2], boxes)
self.assertAllClose(result[3], masks) self.assertAllClose(result[3], masks)
def test_filter_masked_classes(self):
classes = np.zeros((2, 3, 5), dtype=np.float32)
classes[0, 0] = [1.0, 0.0, 0.0, 0.0, 0.0]
classes[0, 1] = [0.0, 1.0, 0.0, 0.0, 0.0]
classes[0, 2] = [0.0, 0.0, 1.0, 0.0, 0.0]
classes[1, 0] = [0.0, 0.0, 0.0, 1.0, 0.0]
classes[1, 1] = [0.0, 0.0, 0.0, 0.0, 1.0]
classes[1, 2] = [0.0, 0.0, 0.0, 0.0, 1.0]
classes = tf.constant(classes)
weights = tf.constant([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
masks = tf.ones((2, 3, 32, 32), dtype=tf.float32)
classes, weights, masks = deepmac_meta_arch.filter_masked_classes(
[3, 4], classes, weights, masks)
expected_classes = np.zeros((2, 3, 5))
expected_classes[0, 0] = [0.0, 0.0, 0.0, 0.0, 0.0]
expected_classes[0, 1] = [0.0, 0.0, 0.0, 0.0, 0.0]
expected_classes[0, 2] = [0.0, 0.0, 1.0, 0.0, 0.0]
expected_classes[1, 0] = [0.0, 0.0, 0.0, 1.0, 0.0]
expected_classes[1, 1] = [0.0, 0.0, 0.0, 0.0, 0.0]
expected_classes[1, 2] = [0.0, 0.0, 0.0, 0.0, 0.0]
self.assertAllClose(expected_classes, classes.numpy())
self.assertAllClose(np.array(([0.0, 0.0, 1.0], [1.0, 0.0, 0.0])), weights)
self.assertAllClose(masks[0, 0], np.zeros((32, 32)))
self.assertAllClose(masks[0, 1], np.zeros((32, 32)))
self.assertAllClose(masks[0, 2], np.ones((32, 32)))
self.assertAllClose(masks[1, 0], np.ones((32, 32)))
self.assertAllClose(masks[1, 1], np.zeros((32, 32)))
def test_fill_boxes(self): def test_fill_boxes(self):
boxes = tf.constant([[0., 0., 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]]) boxes = tf.constant([[[0., 0., 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]],
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
filled_boxes = deepmac_meta_arch.fill_boxes(boxes, 32, 32) filled_boxes = deepmac_meta_arch.fill_boxes(boxes, 32, 32)
expected = np.zeros((2, 32, 32)) expected = np.zeros((2, 2, 32, 32))
expected[0, :17, :17] = 1.0 expected[0, 0, :17, :17] = 1.0
expected[1, 16:, 16:] = 1.0 expected[0, 1, 16:, 16:] = 1.0
expected[1, 0, :, :] = 1.0
filled_boxes = filled_boxes.numpy()
self.assertAllClose(expected[0, 0], filled_boxes[0, 0], rtol=1e-3)
self.assertAllClose(expected[0, 1], filled_boxes[0, 1], rtol=1e-3)
self.assertAllClose(expected[1, 0], filled_boxes[1, 0], rtol=1e-3)
self.assertAllClose(expected, filled_boxes.numpy(), rtol=1e-3) def test_flatten_and_unpack(self):
t = tf.random.uniform((2, 3, 4, 5, 6))
flatten = tf.function(deepmac_meta_arch.flatten_first2_dims)
unpack = tf.function(deepmac_meta_arch.unpack_first2_dims)
result, d1, d2 = flatten(t)
result = unpack(result, d1, d2)
self.assertAllClose(result.numpy(), t)
def test_crop_and_resize_instance_masks(self): def test_crop_and_resize_instance_masks(self):
boxes = tf.zeros((5, 4)) boxes = tf.zeros((8, 5, 4))
masks = tf.zeros((5, 128, 128)) masks = tf.zeros((8, 5, 128, 128))
output = deepmac_meta_arch.crop_and_resize_instance_masks( output = deepmac_meta_arch.crop_and_resize_instance_masks(
masks, boxes, 32) masks, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32)) self.assertEqual(output.shape, (8, 5, 32, 32))
def test_crop_and_resize_feature_map(self):
boxes = tf.zeros((5, 4))
features = tf.zeros((128, 128, 7))
output = deepmac_meta_arch.crop_and_resize_feature_map(
features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7))
def test_embedding_projection_prob_shape(self): def test_embedding_projection_prob_shape(self):
dist = deepmac_meta_arch.embedding_projection( dist = deepmac_meta_arch.embedding_projection(
...@@ -262,73 +304,75 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -262,73 +304,75 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
def test_generate_2d_neighbors_shape(self): def test_generate_2d_neighbors_shape(self):
inp = tf.zeros((13, 14, 3)) inp = tf.zeros((5, 13, 14, 3))
out = deepmac_meta_arch.generate_2d_neighbors(inp) out = deepmac_meta_arch.generate_2d_neighbors(inp)
self.assertEqual((8, 13, 14, 3), out.shape) self.assertEqual((8, 5, 13, 14, 3), out.shape)
def test_generate_2d_neighbors(self): def test_generate_2d_neighbors(self):
inp = np.arange(16).reshape(4, 4).astype(np.float32) inp = np.arange(16).reshape(4, 4).astype(np.float32)
inp = tf.stack([inp, inp * 2], axis=2) inp = tf.stack([inp, inp * 2], axis=2)
inp = tf.reshape(inp, (1, 4, 4, 2))
out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=1) out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=1)
self.assertEqual((8, 4, 4, 2), out.shape) self.assertEqual((8, 1, 4, 4, 2), out.shape)
for i in range(2): for i in range(2):
expected = np.array([0, 1, 2, 4, 6, 8, 9, 10]) * (i + 1) expected = np.array([0, 1, 2, 4, 6, 8, 9, 10]) * (i + 1)
self.assertAllEqual(out[:, 1, 1, i], expected) self.assertAllEqual(out[:, 0, 1, 1, i], expected)
expected = np.array([1, 2, 3, 5, 7, 9, 10, 11]) * (i + 1) expected = np.array([1, 2, 3, 5, 7, 9, 10, 11]) * (i + 1)
self.assertAllEqual(out[:, 1, 2, i], expected) self.assertAllEqual(out[:, 0, 1, 2, i], expected)
expected = np.array([4, 5, 6, 8, 10, 12, 13, 14]) * (i + 1) expected = np.array([4, 5, 6, 8, 10, 12, 13, 14]) * (i + 1)
self.assertAllEqual(out[:, 2, 1, i], expected) self.assertAllEqual(out[:, 0, 2, 1, i], expected)
expected = np.array([5, 6, 7, 9, 11, 13, 14, 15]) * (i + 1) expected = np.array([5, 6, 7, 9, 11, 13, 14, 15]) * (i + 1)
self.assertAllEqual(out[:, 2, 2, i], expected) self.assertAllEqual(out[:, 0, 2, 2, i], expected)
def test_generate_2d_neighbors_dilation2(self): def test_generate_2d_neighbors_dilation2(self):
inp = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
inp = np.arange(16).reshape(4, 4, 1).astype(np.float32)
out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=2) out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=2)
self.assertEqual((8, 4, 4, 1), out.shape) self.assertEqual((8, 1, 4, 4, 1), out.shape)
expected = np.array([0, 0, 0, 0, 2, 0, 8, 10]) expected = np.array([0, 0, 0, 0, 2, 0, 8, 10])
self.assertAllEqual(out[:, 0, 0, 0], expected) self.assertAllEqual(out[:, 0, 0, 0, 0], expected)
def test_dilated_similarity_shape(self): def test_dilated_similarity_shape(self):
fmap = tf.zeros((5, 32, 32, 9))
fmap = tf.zeros((32, 32, 9))
similarity = deepmac_meta_arch.dilated_cross_pixel_similarity( similarity = deepmac_meta_arch.dilated_cross_pixel_similarity(
fmap) fmap)
self.assertEqual((8, 32, 32), similarity.shape) self.assertEqual((8, 5, 32, 32), similarity.shape)
def test_dilated_similarity(self): def test_dilated_similarity(self):
fmap = np.zeros((5, 5, 2), dtype=np.float32) fmap = np.zeros((1, 5, 5, 2), dtype=np.float32)
fmap[0, 0, :] = 1.0 fmap[0, 0, 0, :] = 1.0
fmap[4, 4, :] = 1.0 fmap[0, 4, 4, :] = 1.0
similarity = deepmac_meta_arch.dilated_cross_pixel_similarity( similarity = deepmac_meta_arch.dilated_cross_pixel_similarity(
fmap, theta=1.0, dilation=2) fmap, theta=1.0, dilation=2)
self.assertAlmostEqual(similarity.numpy()[0, 2, 2], self.assertAlmostEqual(similarity.numpy()[0, 0, 2, 2],
np.exp(-np.sqrt(2))) np.exp(-np.sqrt(2)))
def test_dilated_same_instance_mask_shape(self): def test_dilated_same_instance_mask_shape(self):
instances = tf.zeros((2, 5, 32, 32))
instances = tf.zeros((5, 32, 32))
output = deepmac_meta_arch.dilated_cross_same_mask_label(instances) output = deepmac_meta_arch.dilated_cross_same_mask_label(instances)
self.assertEqual((8, 5, 32, 32), output.shape) self.assertEqual((8, 2, 5, 32, 32), output.shape)
def test_dilated_same_instance_mask(self): def test_dilated_same_instance_mask(self):
instances = np.zeros((3, 2, 5, 5), dtype=np.float32)
instances[0, 0, 0, 0] = 1.0
instances[0, 0, 2, 2] = 1.0
instances[0, 0, 4, 4] = 1.0
instances[2, 0, 0, 0] = 1.0
instances[2, 0, 2, 2] = 1.0
instances[2, 0, 4, 4] = 0.0
instances = np.zeros((2, 5, 5), dtype=np.float32)
instances[0, 0, 0] = 1.0
instances[0, 2, 2] = 1.0
instances[0, 4, 4] = 1.0
output = deepmac_meta_arch.dilated_cross_same_mask_label(instances).numpy() output = deepmac_meta_arch.dilated_cross_same_mask_label(instances).numpy()
self.assertAllClose(np.ones((8, 5, 5)), output[:, 1, :, :]) self.assertAllClose(np.ones((8, 2, 5, 5)), output[:, 1, :, :])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 2, 2]) self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 0, 2, 2])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 0], output[:, 2, 0, 2, 2])
def test_per_pixel_single_conv_multiple_instance(self): def test_per_pixel_single_conv_multiple_instance(self):
...@@ -550,151 +594,184 @@ class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase): ...@@ -550,151 +594,184 @@ class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
# TODO(vighneshb): Add batch_size > 1 tests for loss functions.
def setUp(self): # pylint:disable=g-missing-super-call def setUp(self): # pylint:disable=g-missing-super-call
self.model = build_meta_arch() self.model = build_meta_arch()
def test_get_mask_head_input(self): def test_get_mask_head_input(self):
boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]], boxes = tf.constant([[[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]],
[[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]]],
dtype=tf.float32) dtype=tf.float32)
pixel_embedding = np.zeros((32, 32, 4), dtype=np.float32) pixel_embedding = np.zeros((2, 32, 32, 4), dtype=np.float32)
pixel_embedding[:16, :16] = 1.0 pixel_embedding[0, :16, :16] = 1.0
pixel_embedding[16:, 16:] = 2.0 pixel_embedding[0, 16:, 16:] = 2.0
pixel_embedding[1, :16, :16] = 3.0
pixel_embedding[1, 16:, 16:] = 4.0
pixel_embedding = tf.constant(pixel_embedding) pixel_embedding = tf.constant(pixel_embedding)
mask_inputs = self.model._get_mask_head_input(boxes, pixel_embedding) mask_inputs = self.model._get_mask_head_input(boxes, pixel_embedding)
self.assertEqual(mask_inputs.shape, (2, 16, 16, 6)) self.assertEqual(mask_inputs.shape, (2, 2, 16, 16, 6))
y_grid, x_grid = tf.meshgrid(np.linspace(-1.0, 1.0, 16), y_grid, x_grid = tf.meshgrid(np.linspace(-1.0, 1.0, 16),
np.linspace(-1.0, 1.0, 16), indexing='ij') np.linspace(-1.0, 1.0, 16), indexing='ij')
for i in range(2):
mask_input = mask_inputs[i] for i, j in ([0, 0], [0, 1], [1, 0], [1, 1]):
self.assertAllClose(y_grid, mask_input[:, :, 0]) self.assertAllClose(y_grid, mask_inputs[i, j, :, :, 0])
self.assertAllClose(x_grid, mask_input[:, :, 1]) self.assertAllClose(x_grid, mask_inputs[i, j, :, :, 1])
pixel_embedding = mask_input[:, :, 2:]
self.assertAllClose(np.zeros((16, 16, 4)) + i + 1, pixel_embedding) zeros = np.zeros((16, 16, 4))
self.assertAllClose(zeros + 1, mask_inputs[0, 0, :, :, 2:])
self.assertAllClose(zeros + 2, mask_inputs[0, 1, :, :, 2:])
self.assertAllClose(zeros + 3, mask_inputs[1, 0, :, :, 2:])
self.assertAllClose(zeros + 4, mask_inputs[1, 1, :, :, 2:])
def test_get_mask_head_input_no_crop_resize(self): def test_get_mask_head_input_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True) model = build_meta_arch(predict_full_resolution_masks=True)
boxes = tf.constant([[0., 0., 1.0, 1.0], [0.0, 0.0, 0.5, 1.0]], boxes = tf.constant([[[0., 0., 1.0, 1.0], [0.0, 0.0, 0.5, 1.0]],
dtype=tf.float32) [[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
pixel_embedding_np = np.random.randn(32, 32, 4).astype(np.float32) pixel_embedding_np = np.random.randn(2, 32, 32, 4).astype(np.float32)
pixel_embedding = tf.constant(pixel_embedding_np) pixel_embedding = tf.constant(pixel_embedding_np)
mask_inputs = model._get_mask_head_input(boxes, pixel_embedding) mask_inputs = model._get_mask_head_input(boxes, pixel_embedding)
self.assertEqual(mask_inputs.shape, (2, 32, 32, 6)) self.assertEqual(mask_inputs.shape, (2, 2, 32, 32, 6))
y_grid, x_grid = tf.meshgrid(np.linspace(.0, 1.0, 32), y_grid, x_grid = tf.meshgrid(np.linspace(.0, 1.0, 32),
np.linspace(.0, 1.0, 32), indexing='ij') np.linspace(.0, 1.0, 32), indexing='ij')
ys = [0.5, 0.25] self.assertAllClose(y_grid - 0.5, mask_inputs[0, 0, :, :, 0])
xs = [0.5, 0.5] self.assertAllClose(x_grid - 0.5, mask_inputs[0, 0, :, :, 1])
for i in range(2):
mask_input = mask_inputs[i] self.assertAllClose(y_grid - 0.25, mask_inputs[0, 1, :, :, 0])
self.assertAllClose(y_grid - ys[i], mask_input[:, :, 0]) self.assertAllClose(x_grid - 0.5, mask_inputs[0, 1, :, :, 1])
self.assertAllClose(x_grid - xs[i], mask_input[:, :, 1])
pixel_embedding = mask_input[:, :, 2:] self.assertAllClose(y_grid - 0.75, mask_inputs[1, 0, :, :, 0])
self.assertAllClose(pixel_embedding_np, pixel_embedding) self.assertAllClose(x_grid - 0.75, mask_inputs[1, 0, :, :, 1])
self.assertAllClose(y_grid, mask_inputs[1, 1, :, :, 0])
self.assertAllClose(x_grid, mask_inputs[1, 1, :, :, 1])
def test_get_instance_embeddings(self): def test_get_instance_embeddings(self):
embeddings = np.zeros((32, 32, 2)) embeddings = np.zeros((2, 32, 32, 2))
embeddings[8, 8] = 1.0 embeddings[0, 8, 8] = 1.0
embeddings[24, 16] = 2.0 embeddings[0, 24, 16] = 2.0
embeddings[1, 8, 16] = 3.0
embeddings = tf.constant(embeddings) embeddings = tf.constant(embeddings)
boxes = tf.constant([[0., 0., 0.5, 0.5], [0.5, 0.0, 1.0, 1.0]]) boxes = np.zeros((2, 2, 4), dtype=np.float32)
boxes[0, 0] = [0.0, 0.0, 0.5, 0.5]
boxes[0, 1] = [0.5, 0.0, 1.0, 1.0]
boxes[1, 0] = [0.0, 0.0, 0.5, 1.0]
boxes = tf.constant(boxes)
center_embeddings = self.model._get_instance_embeddings(boxes, embeddings) center_embeddings = self.model._get_instance_embeddings(boxes, embeddings)
self.assertAllClose(center_embeddings, [[1.0, 1.0], [2.0, 2.0]]) self.assertAllClose(center_embeddings[0, 0], [1.0, 1.0])
self.assertAllClose(center_embeddings[0, 1], [2.0, 2.0])
self.assertAllClose(center_embeddings[1, 0], [3.0, 3.0])
def test_get_groundtruth_mask_output(self): def test_get_groundtruth_mask_output(self):
boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]], boxes = np.zeros((2, 2, 4))
dtype=tf.float32) masks = np.zeros((2, 2, 32, 32))
masks = np.zeros((2, 32, 32), dtype=np.float32)
masks[0, :16, :16] = 0.5 boxes[0, 0] = [0.0, 0.0, 0.25, 0.25]
masks[1, 16:, 16:] = 0.1 boxes[0, 1] = [0.75, 0.75, 1.0, 1.0]
boxes[1, 0] = [0.0, 0.0, 0.5, 1.0]
masks = np.zeros((2, 2, 32, 32), dtype=np.float32)
masks[0, 0, :16, :16] = 0.5
masks[0, 1, 16:, 16:] = 0.1
masks[1, 0, :17, :] = 0.3
masks = self.model._get_groundtruth_mask_output(boxes, masks) masks = self.model._get_groundtruth_mask_output(boxes, masks)
self.assertEqual(masks.shape, (2, 16, 16)) self.assertEqual(masks.shape, (2, 2, 16, 16))
self.assertAllClose(masks[0], np.zeros((16, 16)) + 0.5) self.assertAllClose(masks[0, 0], np.zeros((16, 16)) + 0.5)
self.assertAllClose(masks[1], np.zeros((16, 16)) + 0.1) self.assertAllClose(masks[0, 1], np.zeros((16, 16)) + 0.1)
self.assertAllClose(masks[1, 0], np.zeros((16, 16)) + 0.3)
def test_get_groundtruth_mask_output_crop_resize(self): def test_get_groundtruth_mask_output_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True) model = build_meta_arch(predict_full_resolution_masks=True)
boxes = tf.constant([[0., 0., 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], boxes = tf.zeros((2, 5, 4))
dtype=tf.float32) masks = tf.ones((2, 5, 32, 32))
masks = tf.ones((2, 32, 32))
masks = model._get_groundtruth_mask_output(boxes, masks) masks = model._get_groundtruth_mask_output(boxes, masks)
self.assertAllClose(masks, np.ones((2, 32, 32))) self.assertAllClose(masks, np.ones((2, 5, 32, 32)))
def test_per_instance_loss(self): def test_predict(self):
model = build_meta_arch() tf.keras.backend.set_learning_phase(True)
model._mask_net = MockMaskNet() self.model.provide_groundtruth(
boxes = tf.constant([[0.0, 0.0, 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]]) groundtruth_boxes_list=[tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)],
masks = np.zeros((2, 32, 32), dtype=np.float32) groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)],
masks[0, :16, :16] = 1.0 groundtruth_weights_list=[tf.ones(5)],
masks[1, 16:, 16:] = 1.0 groundtruth_masks_list=[tf.ones((5, 32, 32))])
masks = tf.constant(masks) prediction = self.model.predict(tf.zeros((1, 32, 32, 3)), None)
self.assertEqual(prediction['MASK_LOGITS_GT_BOXES'][0].shape,
(1, 5, 16, 16))
def test_loss(self):
loss_dict = model._compute_per_instance_deepmac_losses( model = build_meta_arch()
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)), boxes = tf.constant([[[0.0, 0.0, 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]]])
tf.zeros((16, 16, 3))) masks = np.zeros((1, 2, 32, 32), dtype=np.float32)
masks[0, 0, :16, :16] = 1.0
masks[0, 1, 16:, 16:] = 1.0
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
loss_dict = model._compute_deepmac_losses(
boxes, masks_pred, masks, tf.zeros((1, 16, 16, 3)))
self.assertAllClose( self.assertAllClose(
loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION], loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9))) np.zeros((1, 2)) - tf.math.log(tf.nn.sigmoid(0.9)))
def test_per_instance_loss_no_crop_resize(self): def test_loss_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True) model = build_meta_arch(predict_full_resolution_masks=True)
model._mask_net = MockMaskNet() boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
boxes = tf.constant([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]) masks = tf.ones((1, 2, 128, 128), dtype=tf.float32)
masks = np.ones((2, 128, 128), dtype=np.float32) masks_pred = tf.fill((1, 2, 32, 32), 0.9)
masks = tf.constant(masks)
loss_dict = model._compute_per_instance_deepmac_losses( loss_dict = model._compute_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)), boxes, masks_pred, masks, tf.zeros((1, 32, 32, 3)))
tf.zeros((32, 32, 3)))
self.assertAllClose( self.assertAllClose(
loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION], loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9))) np.zeros((1, 2)) - tf.math.log(tf.nn.sigmoid(0.9)))
def test_per_instance_loss_no_crop_resize_dice(self): def test_loss_no_crop_resize_dice(self):
model = build_meta_arch(predict_full_resolution_masks=True, model = build_meta_arch(predict_full_resolution_masks=True,
use_dice_loss=True) use_dice_loss=True)
model._mask_net = MockMaskNet() boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
boxes = tf.constant([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]) masks = np.ones((1, 2, 128, 128), dtype=np.float32)
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks) masks = tf.constant(masks)
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
loss_dict = model._compute_per_instance_deepmac_losses( loss_dict = model._compute_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)), boxes, masks_pred, masks, tf.zeros((1, 32, 32, 3)))
tf.zeros((32, 32, 3)))
pred = tf.nn.sigmoid(0.9) pred = tf.nn.sigmoid(0.9)
expected = (1.0 - ((2.0 * pred) / (1.0 + pred))) expected = (1.0 - ((2.0 * pred) / (1.0 + pred)))
self.assertAllClose(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION], self.assertAllClose(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
[expected, expected], rtol=1e-3) [[expected, expected]], rtol=1e-3)
def test_empty_masks(self): def test_empty_masks(self):
boxes = tf.zeros([0, 4])
masks = tf.zeros([0, 128, 128])
loss_dict = self.model._compute_per_instance_deepmac_losses( boxes = tf.zeros([1, 0, 4])
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)), masks = tf.zeros([1, 0, 128, 128])
tf.zeros((16, 16, 3)))
loss_dict = self.model._compute_deepmac_losses(
boxes, masks, masks,
tf.zeros((1, 16, 16, 3)))
self.assertEqual(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION].shape, self.assertEqual(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION].shape,
(0,)) (1, 0))
def test_postprocess(self): def test_postprocess(self):
model = build_meta_arch() model = build_meta_arch()
model._mask_net = MockMaskNet() model._mask_net = MockMaskNet()
boxes = np.zeros((2, 3, 4), dtype=np.float32) boxes = np.zeros((2, 3, 4), dtype=np.float32)
...@@ -708,7 +785,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -708,7 +785,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16))) self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16)))
def test_postprocess_emb_proj(self): def test_postprocess_emb_proj(self):
model = build_meta_arch(network_type='embedding_projection', model = build_meta_arch(network_type='embedding_projection',
use_instance_embedding=False, use_instance_embedding=False,
use_xy=False, pixel_embedding_dim=8, use_xy=False, pixel_embedding_dim=8,
...@@ -724,7 +800,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -724,7 +800,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(masks.shape, (2, 3, 16, 16)) self.assertEqual(masks.shape, (2, 3, 16, 16))
def test_postprocess_emb_proj_fullres(self): def test_postprocess_emb_proj_fullres(self):
model = build_meta_arch(network_type='embedding_projection', model = build_meta_arch(network_type='embedding_projection',
predict_full_resolution_masks=True, predict_full_resolution_masks=True,
use_instance_embedding=False, use_instance_embedding=False,
...@@ -751,17 +826,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -751,17 +826,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
prob = tf.nn.sigmoid(0.9).numpy() prob = tf.nn.sigmoid(0.9).numpy()
self.assertAllClose(masks, prob * np.ones((2, 3, 128, 128))) self.assertAllClose(masks, prob * np.ones((2, 3, 128, 128)))
def test_crop_masks_within_boxes(self):
masks = np.zeros((2, 32, 32))
masks[0, :16, :16] = 1.0
masks[1, 16:, 16:] = 1.0
boxes = tf.constant([[0.0, 0.0, 15.0 / 32, 15.0 / 32],
[0.5, 0.5, 1.0, 1]])
masks = deepmac_meta_arch.crop_masks_within_boxes(
masks, boxes, 128)
masks = (masks.numpy() > 0.0).astype(np.float32)
self.assertAlmostEqual(masks.sum(), 2 * 128 * 128)
def test_transform_boxes_to_feature_coordinates(self): def test_transform_boxes_to_feature_coordinates(self):
batch_size = 2 batch_size = 2
model = build_meta_arch() model = build_meta_arch()
...@@ -816,13 +880,13 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -816,13 +880,13 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def test_box_consistency_loss(self): def test_box_consistency_loss(self):
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]]) boxes_gt = tf.constant([[[0., 0., 0.49, 1.0]]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]]) boxes_jittered = tf.constant([[[0.0, 0.0, 1.0, 1.0]]])
mask_prediction = np.zeros((1, 32, 32)).astype(np.float32) mask_prediction = np.zeros((1, 1, 32, 32)).astype(np.float32)
mask_prediction[0, :24, :24] = 1.0 mask_prediction[0, 0, :24, :24] = 1.0
loss = self.model._compute_per_instance_box_consistency_loss( loss = self.model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction)) boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = tf.nn.sigmoid_cross_entropy_with_logits( yloss = tf.nn.sigmoid_cross_entropy_with_logits(
...@@ -834,39 +898,39 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -834,39 +898,39 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
yloss_mean = tf.reduce_mean(yloss) yloss_mean = tf.reduce_mean(yloss)
xloss_mean = tf.reduce_mean(xloss) xloss_mean = tf.reduce_mean(xloss)
self.assertAllClose(loss, [yloss_mean + xloss_mean]) self.assertAllClose(loss[0], [yloss_mean + xloss_mean])
def test_box_consistency_loss_with_tightness(self): def test_box_consistency_loss_with_tightness(self):
boxes_gt = tf.constant([[0., 0., 0.49, 0.49]]) boxes_gt = tf.constant([[[0., 0., 0.49, 0.49]]])
boxes_jittered = None boxes_jittered = None
mask_prediction = np.zeros((1, 8, 8)).astype(np.float32) - 1e10 mask_prediction = np.zeros((1, 1, 8, 8)).astype(np.float32) - 1e10
mask_prediction[0, :4, :4] = 1e10 mask_prediction[0, 0, :4, :4] = 1e10
model = build_meta_arch(box_consistency_tightness=True, model = build_meta_arch(box_consistency_tightness=True,
predict_full_resolution_masks=True) predict_full_resolution_masks=True)
loss = model._compute_per_instance_box_consistency_loss( loss = model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction)) boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAllClose(loss, [0.0]) self.assertAllClose(loss[0], [0.0])
def test_box_consistency_loss_gt_count(self): def test_box_consistency_loss_gt_count(self):
boxes_gt = tf.constant([ boxes_gt = tf.constant([[
[0., 0., 1.0, 1.0], [0., 0., 1.0, 1.0],
[0., 0., 0.49, 0.49]]) [0., 0., 0.49, 0.49]]])
boxes_jittered = None boxes_jittered = None
mask_prediction = np.zeros((2, 32, 32)).astype(np.float32) mask_prediction = np.zeros((1, 2, 32, 32)).astype(np.float32)
mask_prediction[0, :16, :16] = 1.0 mask_prediction[0, 0, :16, :16] = 1.0
mask_prediction[1, :8, :8] = 1.0 mask_prediction[0, 1, :8, :8] = 1.0
model = build_meta_arch( model = build_meta_arch(
box_consistency_loss_normalize='normalize_groundtruth_count', box_consistency_loss_normalize='normalize_groundtruth_count',
predict_full_resolution_masks=True) predict_full_resolution_masks=True)
loss_func = tf.function( loss_func = (
model._compute_per_instance_box_consistency_loss) model._compute_box_consistency_loss)
loss = loss_func( loss = loss_func(
boxes_gt, boxes_jittered, tf.constant(mask_prediction)) boxes_gt, boxes_jittered, tf.constant(mask_prediction))
...@@ -877,7 +941,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -877,7 +941,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
xloss = yloss xloss = yloss
xloss_mean = tf.reduce_sum(xloss) xloss_mean = tf.reduce_sum(xloss)
self.assertAllClose(loss[0], yloss_mean + xloss_mean) self.assertAllClose(loss[0, 0], yloss_mean + xloss_mean)
yloss = tf.nn.sigmoid_cross_entropy_with_logits( yloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 16 + [0.0] * 16), labels=tf.constant([1.0] * 16 + [0.0] * 16),
...@@ -885,21 +949,20 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -885,21 +949,20 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
yloss_mean = tf.reduce_sum(yloss) yloss_mean = tf.reduce_sum(yloss)
xloss = yloss xloss = yloss
xloss_mean = tf.reduce_sum(xloss) xloss_mean = tf.reduce_sum(xloss)
self.assertAllClose(loss[1], yloss_mean + xloss_mean) self.assertAllClose(loss[0, 1], yloss_mean + xloss_mean)
def test_box_consistency_loss_balanced(self): def test_box_consistency_loss_balanced(self):
boxes_gt = tf.constant([[
boxes_gt = tf.constant([ [0., 0., 0.49, 0.49]]])
[0., 0., 0.49, 0.49]])
boxes_jittered = None boxes_jittered = None
mask_prediction = np.zeros((1, 32, 32)).astype(np.float32) mask_prediction = np.zeros((1, 1, 32, 32)).astype(np.float32)
mask_prediction[0] = 1.0 mask_prediction[0, 0] = 1.0
model = build_meta_arch(box_consistency_loss_normalize='normalize_balanced', model = build_meta_arch(box_consistency_loss_normalize='normalize_balanced',
predict_full_resolution_masks=True) predict_full_resolution_masks=True)
loss_func = tf.function( loss_func = tf.function(
model._compute_per_instance_box_consistency_loss) model._compute_box_consistency_loss)
loss = loss_func( loss = loss_func(
boxes_gt, boxes_jittered, tf.constant(mask_prediction)) boxes_gt, boxes_jittered, tf.constant(mask_prediction))
...@@ -909,63 +972,64 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -909,63 +972,64 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
yloss_mean = tf.reduce_sum(yloss) / 16.0 yloss_mean = tf.reduce_sum(yloss) / 16.0
xloss_mean = yloss_mean xloss_mean = yloss_mean
self.assertAllClose(loss[0], yloss_mean + xloss_mean) self.assertAllClose(loss[0, 0], yloss_mean + xloss_mean)
def test_box_consistency_dice_loss(self): def test_box_consistency_dice_loss(self):
model = build_meta_arch(use_dice_loss=True) model = build_meta_arch(use_dice_loss=True)
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]]) boxes_gt = tf.constant([[[0., 0., 0.49, 1.0]]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]]) boxes_jittered = tf.constant([[[0.0, 0.0, 1.0, 1.0]]])
almost_inf = 1e10 almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32) mask_prediction = np.full((1, 1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, :24, :24] = almost_inf mask_prediction[0, 0, :24, :24] = almost_inf
loss = model._compute_per_instance_box_consistency_loss( loss = model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction)) boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = 1 - 6.0 / 7 yloss = 1 - 6.0 / 7
xloss = 0.2 xloss = 0.2
self.assertAllClose(loss, [yloss + xloss]) self.assertAllClose(loss, [[yloss + xloss]])
def test_color_consistency_loss_full_res_shape(self): def test_color_consistency_loss_full_res_shape(self):
model = build_meta_arch(use_dice_loss=True, model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True) predict_full_resolution_masks=True)
boxes = tf.zeros((3, 4)) boxes = tf.zeros((5, 3, 4))
img = tf.zeros((32, 32, 3)) img = tf.zeros((5, 32, 32, 3))
mask_logits = tf.zeros((3, 32, 32)) mask_logits = tf.zeros((5, 3, 32, 32))
loss = model._compute_per_instance_color_consistency_loss( loss = model._compute_color_consistency_loss(
boxes, img, mask_logits) boxes, img, mask_logits)
self.assertEqual([3], loss.shape) self.assertEqual([5, 3], loss.shape)
def test_color_consistency_1_threshold(self): def test_color_consistency_1_threshold(self):
model = build_meta_arch(predict_full_resolution_masks=True, model = build_meta_arch(predict_full_resolution_masks=True,
color_consistency_threshold=0.99) color_consistency_threshold=0.99)
boxes = tf.zeros((3, 4)) boxes = tf.zeros((5, 3, 4))
img = tf.zeros((32, 32, 3)) img = tf.zeros((5, 32, 32, 3))
mask_logits = tf.zeros((3, 32, 32)) - 1e4 mask_logits = tf.zeros((5, 3, 32, 32)) - 1e4
loss = model._compute_per_instance_color_consistency_loss( loss = model._compute_color_consistency_loss(
boxes, img, mask_logits) boxes, img, mask_logits)
self.assertAllClose(loss, np.zeros(3)) self.assertAllClose(loss, np.zeros((5, 3)))
def test_box_consistency_dice_loss_full_res(self): def test_box_consistency_dice_loss_full_res(self):
model = build_meta_arch(use_dice_loss=True, model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True) predict_full_resolution_masks=True)
boxes_gt = tf.constant([[0., 0., 1.0, 1.0]]) boxes_gt = tf.constant([[[0., 0., 1.0, 1.0]]])
boxes_jittered = None boxes_jittered = None
size = 32
almost_inf = 1e10 almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32) mask_prediction = np.full((1, 1, size, size), -almost_inf, dtype=np.float32)
mask_prediction[0, :16, :32] = almost_inf mask_prediction[0, 0, :(size // 2), :] = almost_inf
loss = model._compute_per_instance_box_consistency_loss( loss = model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction)) boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAlmostEqual(loss[0].numpy(), 1 / 3) self.assertAlmostEqual(loss[0, 0].numpy(), 1 / 3)
def test_get_lab_image_shape(self): def test_get_lab_image_shape(self):
...@@ -975,18 +1039,18 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -975,18 +1039,18 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def test_loss_keys(self): def test_loss_keys(self):
model = build_meta_arch(use_dice_loss=True) model = build_meta_arch(use_dice_loss=True)
prediction = { prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)), 'preprocessed_inputs': tf.random.normal((3, 32, 32, 3)),
'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 17))] * 2, 'MASK_LOGITS_GT_BOXES': [tf.random.normal((3, 5, 8, 8))] * 2,
'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 19))] * 2, 'object_center': [tf.random.normal((3, 8, 8, 6))] * 2,
'object_center': [tf.random.normal((1, 8, 8, 6))] * 2, 'box/offset': [tf.random.normal((3, 8, 8, 2))] * 2,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * 2, 'box/scale': [tf.random.normal((3, 8, 8, 2))] * 2
'box/scale': [tf.random.normal((1, 8, 8, 2))] * 2
} }
model.provide_groundtruth( model.provide_groundtruth(
groundtruth_boxes_list=[tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)], groundtruth_boxes_list=[
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)], tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] * 3,
groundtruth_weights_list=[tf.ones(5)], groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)] * 3,
groundtruth_masks_list=[tf.ones((5, 32, 32))]) groundtruth_weights_list=[tf.ones(5)] * 3,
groundtruth_masks_list=[tf.ones((5, 32, 32))] * 3)
loss = model.loss(prediction, tf.constant([[32, 32, 3.0]])) loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0) self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
...@@ -1008,8 +1072,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1008,8 +1072,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
num_stages = 1 num_stages = 1
prediction = { prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)), 'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages, 'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages,
'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages, 'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages, 'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages 'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
...@@ -1066,6 +1129,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1066,6 +1129,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
f'{mask_loss} did not respond to change in weight.') f'{mask_loss} did not respond to change in weight.')
def test_color_consistency_warmup(self): def test_color_consistency_warmup(self):
tf.keras.backend.set_learning_phase(True)
model = build_meta_arch( model = build_meta_arch(
use_dice_loss=True, use_dice_loss=True,
predict_full_resolution_masks=True, predict_full_resolution_masks=True,
...@@ -1079,8 +1143,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase): ...@@ -1079,8 +1143,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
num_stages = 1 num_stages = 1
prediction = { prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)), 'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages, 'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages,
'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages, 'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages, 'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages 'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
......
...@@ -403,7 +403,7 @@ message CenterNet { ...@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613 // Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 24 // Next ID 25
message DeepMACMaskEstimation { message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions. // The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1; optional ClassificationLoss classification_loss = 1;
...@@ -485,6 +485,14 @@ message CenterNet { ...@@ -485,6 +485,14 @@ message CenterNet {
optional int32 color_consistency_warmup_start = 23 [default=0]; optional int32 color_consistency_warmup_start = 23 [default=0];
// DeepMAC has been refactored to process the entire batch at once,
// instead of the previous (simple) approach of processing one sample at
// a time. Because of this, the memory consumption has increased and
// it's crucial to only feed the mask head the last stage outputs
// from the hourglass. Doing so halves the memory requirement of the
// mask head and does not cause a drop in evaluation metrics.
optional bool use_only_last_stage = 24 [default=false];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment