Commit 457bcb85 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Refactor DeepMAC to process full batch and return mask logits with predict()

PiperOrigin-RevId: 426181961
parent c3f2134b
......@@ -1050,6 +1050,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
else:
raise ValueError(f'Unknown heatmap type - {self._box_heatmap_type}')
heatmap = tf.stop_gradient(heatmap)
heatmaps.append(heatmap)
# Return the stacked heatmaps over the batch.
......
......@@ -30,6 +30,7 @@ if tf_version.is_tf2():
INSTANCE_EMBEDDING = 'INSTANCE_EMBEDDING'
PIXEL_EMBEDDING = 'PIXEL_EMBEDDING'
MASK_LOGITS_GT_BOXES = 'MASK_LOGITS_GT_BOXES'
DEEP_MASK_ESTIMATION = 'deep_mask_estimation'
DEEP_MASK_BOX_CONSISTENCY = 'deep_mask_box_consistency'
DEEP_MASK_COLOR_CONSISTENCY = 'deep_mask_color_consistency'
......@@ -50,7 +51,8 @@ DeepMACParams = collections.namedtuple('DeepMACParams', [
'box_consistency_loss_weight', 'color_consistency_threshold',
'color_consistency_dilation', 'color_consistency_loss_weight',
'box_consistency_loss_normalize', 'box_consistency_tightness',
'color_consistency_warmup_steps', 'color_consistency_warmup_start'
'color_consistency_warmup_steps', 'color_consistency_warmup_start',
'use_only_last_stage'
])
......@@ -140,33 +142,24 @@ def _get_deepmac_network_by_type(name, num_init_channels, mask_size=None):
raise ValueError('Unknown network type {}'.format(name))
def crop_masks_within_boxes(masks, boxes, output_size):
  """Crops masks to lie tightly within the boxes.

  Args:
    masks: A [num_instances, height, width] float tensor of masks.
    boxes: A [num_instances, 4] sized tensor of normalized bounding boxes.
    output_size: The height and width of the output masks.

  Returns:
    masks: A [num_instances, output_size, output_size] tensor of masks which
      are cropped to be tightly within the given boxes and resized.
  """
  # Expand a channel axis on the masks and a per-mask box axis on the boxes so
  # each mask is cropped by exactly one box; the extra axes (box index and
  # channel) are stripped from the result below.
  masks = spatial_transform_ops.matmul_crop_and_resize(
      masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
      [output_size, output_size])
  return masks[:, 0, :, :, 0]
def _resize_instance_masks_non_empty(masks, shape):
  """Resize a non-empty tensor of masks to the given shape.

  Args:
    masks: A [batch_size, num_instances, height, width] float tensor of masks.
    shape: A tuple of (new_height, new_width).

  Returns:
    A [batch_size, num_instances, new_height, new_width] float tensor of
    bilinearly resized masks.
  """
  height, width = shape
  # Collapse (batch, instance) into a single leading axis and add a channel
  # axis so tf.image.resize sees a standard [N, H, W, 1] image batch; the two
  # leading dims are restored afterwards.
  flattened_masks, batch_size, num_instances = flatten_first2_dims(masks)
  flattened_masks = flattened_masks[:, :, :, tf.newaxis]
  flattened_masks = tf.image.resize(
      flattened_masks, (height, width),
      method=tf.image.ResizeMethod.BILINEAR)
  return unpack_first2_dims(
      flattened_masks[:, :, :, 0], batch_size, num_instances)
def resize_instance_masks(masks, shape):
  """Resize a batch of instance masks to the given shape.

  Args:
    masks: A [batch_size, num_instances, height, width] float tensor of masks.
    shape: A tuple of (new_height, new_width).

  Returns:
    A [batch_size, num_instances, new_height, new_width] float tensor of
    bilinearly resized masks. When there are zero instances, an all-zero
    tensor of the requested shape is returned instead.
  """
  batch_size, num_instances = tf.shape(masks)[0], tf.shape(masks)[1]
  # Guard against an empty instance dimension, which the resize path cannot
  # handle; the zero-instance branch produces a correctly-shaped empty result.
  return tf.cond(
      tf.shape(masks)[1] == 0,
      lambda: tf.zeros((batch_size, num_instances, shape[0], shape[1])),
      lambda: _resize_instance_masks_non_empty(masks, shape))
def filter_masked_classes(masked_class_ids, classes, weights, masks):
......@@ -175,94 +168,132 @@ def filter_masked_classes(masked_class_ids, classes, weights, masks):
Args:
masked_class_ids: A list of class IDs allowed to have masks. These class IDs
are 1-indexed.
classes: A [num_instances, num_classes] float tensor containing the one-hot
encoded classes.
weights: A [num_instances] float tensor containing the weights of each
sample.
masks: A [num_instances, height, width] tensor containing the mask per
instance.
classes: A [batch_size, num_instances, num_classes] float tensor containing
the one-hot encoded classes.
weights: A [batch_size, num_instances] float tensor containing the weights
of each sample.
masks: A [batch_size, num_instances, height, width] tensor containing the
mask per instance.
Returns:
classes_filtered: A [num_instances, num_classes] float tensor containing the
one-hot encoded classes with classes not in masked_class_ids zeroed out.
weights_filtered: A [num_instances] float tensor containing the weights of
each sample with instances whose classes aren't in masked_class_ids
zeroed out.
masks_filtered: A [num_instances, height, width] tensor containing the mask
per instance with masks not belonging to masked_class_ids zeroed out.
classes_filtered: A [batch_size, num_instances, num_classes] float tensor
containing the one-hot encoded classes with classes not in
masked_class_ids zeroed out.
weights_filtered: A [batch_size, num_instances] float tensor containing the
weights of each sample with instances whose classes aren't in
masked_class_ids zeroed out.
masks_filtered: A [batch_size, num_instances, height, width] tensor
containing the mask per instance with masks not belonging to
masked_class_ids zeroed out.
"""
if len(masked_class_ids) == 0: # pylint:disable=g-explicit-length-test
return classes, weights, masks
if tf.shape(classes)[0] == 0:
if tf.shape(classes)[1] == 0:
return classes, weights, masks
masked_class_ids = tf.constant(np.array(masked_class_ids, dtype=np.int32))
label_id_offset = 1
masked_class_ids -= label_id_offset
class_ids = tf.argmax(classes, axis=1, output_type=tf.int32)
class_ids = tf.argmax(classes, axis=2, output_type=tf.int32)
matched_classes = tf.equal(
class_ids[:, tf.newaxis], masked_class_ids[tf.newaxis, :]
class_ids[:, :, tf.newaxis], masked_class_ids[tf.newaxis, tf.newaxis, :]
)
matched_classes = tf.reduce_any(matched_classes, axis=1)
matched_classes = tf.reduce_any(matched_classes, axis=2)
matched_classes = tf.cast(matched_classes, tf.float32)
return (
classes * matched_classes[:, tf.newaxis],
classes * matched_classes[:, :, tf.newaxis],
weights * matched_classes,
masks * matched_classes[:, tf.newaxis, tf.newaxis]
masks * matched_classes[:, :, tf.newaxis, tf.newaxis]
)
def crop_and_resize_feature_map(features, boxes, size):
"""Crop and resize regions from a single feature map given a set of boxes.
def flatten_first2_dims(tensor):
"""Flatten first 2 dimensions of a tensor.
Args:
features: A [H, W, C] float tensor.
boxes: A [N, 4] tensor of norrmalized boxes.
size: int, the size of the output features.
tensor: A tensor with shape [M, N, ....]
Returns:
per_box_features: A [N, size, size, C] tensor of cropped and resized
features.
flattened_tensor: A tensor of shape [M * N, ...]
M: int, the length of the first dimension of the input.
N: int, the length of the second dimension of the input.
"""
return spatial_transform_ops.matmul_crop_and_resize(
features[tf.newaxis], boxes[tf.newaxis], [size, size])[0]
shape = tf.shape(tensor)
d1, d2, rest = shape[0], shape[1], shape[2:]
tensor = tf.reshape(
tensor, tf.concat([[d1 * d2], rest], axis=0))
return tensor, d1, d2
def unpack_first2_dims(tensor, dim1, dim2):
  """Unpack the flattened first dimension of the tensor into 2 dimensions.

  Inverse of `flatten_first2_dims`.

  Args:
    tensor: A tensor of shape [dim1 * dim2, ...]
    dim1: int, the size of the first dimension.
    dim2: int, the size of the second dimension.

  Returns:
    unflattened_tensor: A tensor of shape [dim1, dim2, ...].
  """
  shape = tf.shape(tensor)
  # Only the leading axis is split; all trailing dimensions are preserved.
  result_shape = tf.concat([[dim1, dim2], shape[1:]], axis=0)
  return tf.reshape(tensor, result_shape)
def crop_and_resize_instance_masks(masks, boxes, mask_size):
  """Crop and resize each mask according to the given boxes.

  Args:
    masks: A [B, N, H, W] float tensor.
    boxes: A [B, N, 4] float tensor of normalized boxes.
    mask_size: int, the size of the output masks.

  Returns:
    masks: A [B, N, mask_size, mask_size] float tensor of cropped and resized
      instance masks.
  """
  # Fold (batch, instance) into one axis so each mask is cropped by exactly
  # one box, then restore the two leading dimensions at the end.
  masks, batch_size, num_instances = flatten_first2_dims(masks)
  boxes, _, _ = flatten_first2_dims(boxes)
  cropped_masks = spatial_transform_ops.matmul_crop_and_resize(
      masks[:, :, :, tf.newaxis], boxes[:, tf.newaxis, :],
      [mask_size, mask_size])
  # Drop the singleton box-index and channel axes added above.
  cropped_masks = tf.squeeze(cropped_masks, axis=[1, 4])
  return unpack_first2_dims(cropped_masks, batch_size, num_instances)
def fill_boxes(boxes, height, width):
"""Fills the area included in the box."""
blist = box_list.BoxList(boxes)
blist = box_list_ops.to_absolute_coordinates(blist, height, width)
boxes = blist.get()
"""Fills the area included in the boxes with 1s.
Args:
boxes: A [batch_size, num_instances, 4] shapes float tensor of boxes given
in the normalized coordinate space.
height: int, height of the output image.
width: int, width of the output image.
Returns:
filled_boxes: A [batch_size, num_instances, height, width] shaped float
tensor with 1s in the area that falls inside each box.
"""
ymin, xmin, ymax, xmax = tf.unstack(
boxes[:, tf.newaxis, tf.newaxis, :], 4, axis=3)
boxes[:, :, tf.newaxis, tf.newaxis, :], 4, axis=4)
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
ymin *= height
ymax *= height
xmin *= width
xmax *= width
ygrid, xgrid = tf.meshgrid(tf.range(height), tf.range(width), indexing='ij')
ygrid, xgrid = tf.cast(ygrid, tf.float32), tf.cast(xgrid, tf.float32)
ygrid, xgrid = ygrid[tf.newaxis, :, :], xgrid[tf.newaxis, :, :]
ygrid, xgrid = (ygrid[tf.newaxis, tf.newaxis, :, :],
xgrid[tf.newaxis, tf.newaxis, :, :])
filled_boxes = tf.logical_and(
tf.logical_and(ygrid >= ymin, ygrid <= ymax),
......@@ -289,7 +320,7 @@ def embedding_projection(x, y):
return dot
def _get_2d_neighbors_kenel():
def _get_2d_neighbors_kernel():
"""Returns a conv. kernel that when applies generates 2D neighbors.
Returns:
......@@ -311,20 +342,34 @@ def generate_2d_neighbors(input_tensor, dilation=2):
following ops on TPU won't have to pad the last dimension to 128.
Args:
input_tensor: A float tensor of shape [height, width, channels].
input_tensor: A float tensor of shape [batch_size, height, width, channels].
dilation: int, the dilation factor for considering neighbors.
Returns:
output: A float tensor of all 8 2-D neighbors. of shape
[8, height, width, channels].
[8, batch_size, height, width, channels].
"""
input_tensor = tf.transpose(input_tensor, (2, 0, 1))
input_tensor = input_tensor[:, :, :, tf.newaxis]
kernel = _get_2d_neighbors_kenel()
# TODO(vighneshb) Minimize tranposing here to save memory.
# input_tensor: [B, C, H, W]
input_tensor = tf.transpose(input_tensor, (0, 3, 1, 2))
# input_tensor: [B, C, H, W, 1]
input_tensor = input_tensor[:, :, :, :, tf.newaxis]
# input_tensor: [B * C, H, W, 1]
input_tensor, batch_size, channels = flatten_first2_dims(input_tensor)
kernel = _get_2d_neighbors_kernel()
# output: [B * C, H, W, 8]
output = tf.nn.atrous_conv2d(input_tensor, kernel, rate=dilation,
padding='SAME')
return tf.transpose(output, [3, 1, 2, 0])
# output: [B, C, H, W, 8]
output = unpack_first2_dims(output, batch_size, channels)
# return: [8, B, H, W, C]
return tf.transpose(output, [4, 0, 2, 3, 1])
def gaussian_pixel_similarity(a, b, theta):
......@@ -339,12 +384,12 @@ def dilated_cross_pixel_similarity(feature_map, dilation=2, theta=2.0):
[1]: https://arxiv.org/abs/2012.02310
Args:
feature_map: A float tensor of shape [height, width, channels]
feature_map: A float tensor of shape [batch_size, height, width, channels]
dilation: int, the dilation factor.
theta: The denominator while taking difference inside the gaussian.
Returns:
dilated_similarity: A tensor of shape [8, height, width]
dilated_similarity: A tensor of shape [8, batch_size, height, width]
"""
neighbors = generate_2d_neighbors(feature_map, dilation)
feature_map = feature_map[tf.newaxis]
......@@ -358,21 +403,26 @@ def dilated_cross_same_mask_label(instance_masks, dilation=2):
[1]: https://arxiv.org/abs/2012.02310
Args:
instance_masks: A float tensor of shape [num_instances, height, width]
instance_masks: A float tensor of shape [batch_size, num_instances,
height, width]
dilation: int, the dilation factor.
Returns:
dilated_same_label: A tensor of shape [8, num_instances, height, width]
dilated_same_label: A tensor of shape [8, batch_size, num_instances,
height, width]
"""
instance_masks = tf.transpose(instance_masks, (1, 2, 0))
# instance_masks: [batch_size, height, width, num_instances]
instance_masks = tf.transpose(instance_masks, (0, 2, 3, 1))
# neighbors: [8, batch_size, height, width, num_instances]
neighbors = generate_2d_neighbors(instance_masks, dilation)
# instance_masks = [1, batch_size, height, width, num_instances]
instance_masks = instance_masks[tf.newaxis]
same_mask_prob = ((instance_masks * neighbors) +
((1 - instance_masks) * (1 - neighbors)))
return tf.transpose(same_mask_prob, (0, 3, 1, 2))
return tf.transpose(same_mask_prob, (0, 1, 4, 2, 3))
def _per_pixel_single_conv(input_tensor, params, channels):
......@@ -722,6 +772,10 @@ class MaskHeadNetwork(tf.keras.layers.Layer):
return tf.squeeze(out, axis=-1)
def _batch_gt_list(gt_list):
  """Stacks a per-image list of groundtruth tensors into one batch tensor.

  Args:
    gt_list: A list of length batch_size holding equally-shaped tensors
      (tf.stack requires identical shapes across the list).

  Returns:
    A tensor with the list entries stacked along a new leading batch axis.
  """
  return tf.stack(gt_list, axis=0)
def deepmac_proto_to_params(deepmac_config):
"""Convert proto to named tuple."""
......@@ -765,7 +819,8 @@ def deepmac_proto_to_params(deepmac_config):
color_consistency_warmup_steps=
deepmac_config.color_consistency_warmup_steps,
color_consistency_warmup_start=
deepmac_config.color_consistency_warmup_start
deepmac_config.color_consistency_warmup_start,
use_only_last_stage=deepmac_config.use_only_last_stage
)
......@@ -808,8 +863,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
f'pixel_embedding_dim({pixel_embedding_dim}) '
f'must be same as dim({dim}).')
loss = self._deepmac_params.classification_loss
super(DeepMACMetaArch, self).__init__(
is_training=is_training, add_summaries=add_summaries,
num_classes=num_classes, feature_extractor=feature_extractor,
......@@ -847,60 +900,62 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Get the input to the mask network, given bounding boxes.
Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in
normalized coordinates.
pixel_embedding: A [height, width, embedding_size] float tensor
containing spatial pixel embeddings.
boxes: A [batch_size, num_instances, 4] float tensor containing bounding
boxes in normalized coordinates.
pixel_embedding: A [batch_size, height, width, embedding_size] float
tensor containing spatial pixel embeddings.
Returns:
embedding: A [num_instances, mask_height, mask_width, embedding_size + 2]
float tensor containing the inputs to the mask network. For each
bounding box, we concatenate the normalized box coordinates to the
cropped pixel embeddings. If predict_full_resolution_masks is set,
mask_height and mask_width are the same as height and width of
pixel_embedding. If not, mask_height and mask_width are the same as
mask_size.
embedding: A [batch_size, num_instances, mask_height, mask_width,
embedding_size + 2] float tensor containing the inputs to the mask
network. For each bounding box, we concatenate the normalized box
coordinates to the cropped pixel embeddings. If
predict_full_resolution_masks is set, mask_height and mask_width are
the same as height and width of pixel_embedding. If not, mask_height
and mask_width are the same as mask_size.
"""
num_instances = tf.shape(boxes)[0]
batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks:
num_instances = tf.shape(boxes)[0]
pixel_embedding = pixel_embedding[tf.newaxis, :, :, :]
num_instances = tf.shape(boxes)[1]
pixel_embedding = pixel_embedding[:, tf.newaxis, :, :, :]
pixel_embeddings_processed = tf.tile(pixel_embedding,
[num_instances, 1, 1, 1])
[1, num_instances, 1, 1, 1])
image_shape = tf.shape(pixel_embeddings_processed)
image_height, image_width = image_shape[1], image_shape[2]
image_height, image_width = image_shape[2], image_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(0.0, 1.0, image_height),
tf.linspace(0.0, 1.0, image_width),
indexing='ij')
blist = box_list.BoxList(boxes)
ycenter, xcenter, _, _ = blist.get_center_coordinates_and_sizes()
y_grid = y_grid[tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, :, :]
ycenter = (boxes[:, :, 0] + boxes[:, :, 2]) / 2.0
xcenter = (boxes[:, :, 1] + boxes[:, :, 3]) / 2.0
y_grid = y_grid[tf.newaxis, tf.newaxis, :, :]
x_grid = x_grid[tf.newaxis, tf.newaxis, :, :]
y_grid -= ycenter[:, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=3)
y_grid -= ycenter[:, :, tf.newaxis, tf.newaxis]
x_grid -= xcenter[:, :, tf.newaxis, tf.newaxis]
coords = tf.stack([y_grid, x_grid], axis=4)
else:
# TODO(vighneshb) Explore multilevel_roi_align and align_corners=False.
pixel_embeddings_processed = crop_and_resize_feature_map(
pixel_embedding, boxes, mask_size)
embeddings = spatial_transform_ops.matmul_crop_and_resize(
pixel_embedding, boxes, [mask_size, mask_size])
pixel_embeddings_processed = embeddings
mask_shape = tf.shape(pixel_embeddings_processed)
mask_height, mask_width = mask_shape[1], mask_shape[2]
mask_height, mask_width = mask_shape[2], mask_shape[3]
y_grid, x_grid = tf.meshgrid(tf.linspace(-1.0, 1.0, mask_height),
tf.linspace(-1.0, 1.0, mask_width),
indexing='ij')
coords = tf.stack([y_grid, x_grid], axis=2)
coords = coords[tf.newaxis, :, :, :]
coords = tf.tile(coords, [num_instances, 1, 1, 1])
coords = coords[tf.newaxis, tf.newaxis, :, :, :]
coords = tf.tile(coords, [batch_size, num_instances, 1, 1, 1])
if self._deepmac_params.use_xy:
return tf.concat([coords, pixel_embeddings_processed], axis=3)
return tf.concat([coords, pixel_embeddings_processed], axis=4)
else:
return pixel_embeddings_processed
......@@ -908,43 +963,94 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
"""Return the instance embeddings from bounding box centers.
Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The
coordinates are in normalized input space.
instance_embedding: A [height, width, embedding_size] float tensor
containing the instance embeddings.
boxes: A [batch_size, num_instances, 4] float tensor holding bounding
boxes. The coordinates are in normalized input space.
instance_embedding: A [batch_size, height, width, embedding_size] float
tensor containing the instance embeddings.
Returns:
instance_embeddings: A [num_instances, embedding_size] shaped float tensor
containing the center embedding for each instance.
instance_embeddings: A [batch_size, num_instances, embedding_size]
shaped float tensor containing the center embedding for each instance.
"""
blist = box_list.BoxList(boxes)
output_height = tf.shape(instance_embedding)[0]
output_width = tf.shape(instance_embedding)[1]
blist_output = box_list_ops.to_absolute_coordinates(
blist, output_height, output_width, check_range=False)
(y_center_output, x_center_output,
_, _) = blist_output.get_center_coordinates_and_sizes()
center_coords_output = tf.stack([y_center_output, x_center_output], axis=1)
output_height = tf.cast(tf.shape(instance_embedding)[1], tf.float32)
output_width = tf.cast(tf.shape(instance_embedding)[2], tf.float32)
ymin = boxes[:, :, 0]
xmin = boxes[:, :, 1]
ymax = boxes[:, :, 2]
xmax = boxes[:, :, 3]
y_center_output = (ymin + ymax) * output_height / 2.0
x_center_output = (xmin + xmax) * output_width / 2.0
center_coords_output = tf.stack([y_center_output, x_center_output], axis=2)
center_coords_output_int = tf.cast(center_coords_output, tf.int32)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int)
center_latents = tf.gather_nd(instance_embedding, center_coords_output_int,
batch_dims=1)
return center_latents
def predict(self, preprocessed_inputs, other_inputs):
  """Runs the base CenterNet predict() and adds groundtruth-box mask logits.

  Args:
    preprocessed_inputs: The preprocessed input image batch, forwarded
      unchanged to the base class predict().
    other_inputs: Additional inputs, forwarded unchanged to the base class.

  Returns:
    prediction_dict: The base class prediction dict with one extra entry,
      MASK_LOGITS_GT_BOXES, holding the mask logits predicted for the
      groundtruth boxes.
  """
  prediction_dict = super(DeepMACMetaArch, self).predict(
      preprocessed_inputs, other_inputs)
  mask_logits = self._predict_mask_logits_from_gt_boxes(prediction_dict)
  prediction_dict[MASK_LOGITS_GT_BOXES] = mask_logits
  return prediction_dict
def _predict_mask_logits_from_embeddings(
    self, pixel_embedding, instance_embedding, boxes):
  """Runs the mask head on one (pixel, instance) embedding pair.

  Args:
    pixel_embedding: A [batch_size, height, width, pixel_embedding_size]
      float tensor of per-pixel embeddings.
    instance_embedding: A [batch_size, height, width, embedding_size] float
      tensor of instance embeddings.
    boxes: A [batch_size, num_instances, 4] float tensor of normalized boxes.

  Returns:
    mask_logits: A [batch_size, num_instances, ...] float tensor of predicted
      mask logits, one mask per box.
  """
  mask_input = self._get_mask_head_input(boxes, pixel_embedding)
  # The mask network operates on a flat batch of instances; fold
  # (batch, instance) together for the call and restore both dims afterwards.
  mask_input, batch_size, num_instances = flatten_first2_dims(mask_input)
  instance_embeddings = self._get_instance_embeddings(
      boxes, instance_embedding)
  instance_embeddings, _, _ = flatten_first2_dims(instance_embeddings)
  mask_logits = self._mask_net(
      instance_embeddings, mask_input,
      training=tf.keras.backend.learning_phase())
  mask_logits = unpack_first2_dims(
      mask_logits, batch_size, num_instances)
  return mask_logits
def _predict_mask_logits_from_gt_boxes(self, prediction_dict):
  """Predicts mask logits for the groundtruth boxes at each prediction stage.

  Args:
    prediction_dict: Dict containing INSTANCE_EMBEDDING and PIXEL_EMBEDDING
      lists, one tensor per prediction stage.

  Returns:
    mask_logits_list: A list of mask-logit tensors, one per stage — or a
      single-element list holding only the final stage's output when
      use_only_last_stage is set.
  """
  mask_logits_list = []
  # Stack the per-image groundtruth boxes into a single batched tensor.
  boxes = _batch_gt_list(self.groundtruth_lists(fields.BoxListFields.boxes))
  instance_embedding_list = prediction_dict[INSTANCE_EMBEDDING]
  pixel_embedding_list = prediction_dict[PIXEL_EMBEDDING]
  if self._deepmac_params.use_only_last_stage:
    # Optionally restrict computation to the last stage's embeddings.
    instance_embedding_list = [instance_embedding_list[-1]]
    pixel_embedding_list = [pixel_embedding_list[-1]]
  for (instance_embedding, pixel_embedding) in zip(instance_embedding_list,
                                                   pixel_embedding_list):
    mask_logits_list.append(
        self._predict_mask_logits_from_embeddings(
            pixel_embedding, instance_embedding, boxes))
  return mask_logits_list
def _get_groundtruth_mask_output(self, boxes, masks):
"""Get the expected mask output for each box.
Args:
boxes: A [num_instances, 4] float tensor containing bounding boxes in
normalized coordinates.
masks: A [num_instances, height, width] float tensor containing binary
ground truth masks.
boxes: A [batch_size, num_instances, 4] float tensor containing bounding
boxes in normalized coordinates.
masks: A [batch_size, num_instances, height, width] float tensor
containing binary ground truth masks.
Returns:
masks: If predict_full_resolution_masks is set, masks are not resized
and the size of this tensor is [num_instances, input_height, input_width].
Otherwise, returns a tensor of size [num_instances, mask_size, mask_size].
and the size of this tensor is [batch_size, num_instances,
input_height, input_width]. Otherwise, returns a tensor of size
[batch_size, num_instances, mask_size, mask_size].
"""
mask_size = self._deepmac_params.mask_size
if self._deepmac_params.predict_full_resolution_masks:
return masks
......@@ -957,9 +1063,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return cropped_masks
def _resize_logits_like_gt(self, logits, gt):
  """Resizes predicted mask logits to match the groundtruth mask shape.

  Args:
    logits: A [batch_size, num_instances, height, width] float tensor of
      predicted mask logits.
    gt: A [batch_size, num_instances, gt_height, gt_width] float tensor of
      groundtruth masks.

  Returns:
    A [batch_size, num_instances, gt_height, gt_width] float tensor of
    resized logits.
  """
  # Height and width live at axes 2 and 3 of the batched [B, N, H, W] masks.
  height, width = tf.shape(gt)[2], tf.shape(gt)[3]
  return resize_instance_masks(logits, (height, width))
def _aggregate_classification_loss(self, loss, gt, pred, method):
......@@ -1016,54 +1120,59 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else:
raise ValueError('Unknown loss aggregation - {}'.format(method))
def _compute_mask_prediction_loss(
    self, boxes, mask_logits, mask_gt):
  """Compute the per-instance mask loss.

  Args:
    boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
    mask_logits: A [batch_size, num_instances, height, width] float tensor of
      predicted masks.
    mask_gt: The groundtruth mask of same shape as mask_logits.

  Returns:
    loss: A [batch_size, num_instances] shaped tensor with the loss for each
      instance.
  """
  batch_size, num_instances = tf.shape(boxes)[0], tf.shape(boxes)[1]
  mask_logits = self._resize_logits_like_gt(mask_logits, mask_gt)
  # Flatten (batch, instance) so the classification loss treats every
  # instance as an independent sample; the per-instance losses are reshaped
  # back to [batch_size, num_instances] at the end.
  mask_logits = tf.reshape(mask_logits, [batch_size * num_instances, -1, 1])
  mask_gt = tf.reshape(mask_gt, [batch_size * num_instances, -1, 1])
  loss = self._deepmac_params.classification_loss(
      prediction_tensor=mask_logits,
      target_tensor=mask_gt,
      weights=tf.ones_like(mask_logits))
  loss = self._aggregate_classification_loss(
      loss, mask_gt, mask_logits, 'normalize_auto')
  return tf.reshape(loss, [batch_size, num_instances])
def _compute_per_instance_box_consistency_loss(
def _compute_box_consistency_loss(
self, boxes_gt, boxes_for_crop, mask_logits):
"""Compute the per-instance box consistency loss.
Args:
boxes_gt: A [num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [num_instances, 4] float tensor of augmented boxes,
to be used when using crop-and-resize based mask head.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks.
boxes_gt: A [batch_size, num_instances, 4] float tensor of GT boxes.
boxes_for_crop: A [batch_size, num_instances, 4] float tensor of
augmented boxes, to be used when using crop-and-resize based mask head.
mask_logits: A [batch_size, num_instances, height, width]
float tensor of predicted masks.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
loss: A [batch_size, num_instances] shaped tensor with the loss for
each instance in the batch.
"""
height, width = tf.shape(mask_logits)[1], tf.shape(mask_logits)[2]
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, tf.newaxis]
shape = tf.shape(mask_logits)
batch_size, num_instances, height, width = (
shape[0], shape[1], shape[2], shape[3])
filled_boxes = fill_boxes(boxes_gt, height, width)[:, :, :, :, tf.newaxis]
mask_logits = mask_logits[:, :, :, :, tf.newaxis]
if self._deepmac_params.predict_full_resolution_masks:
gt_crop = filled_boxes[:, :, :, 0]
pred_crop = mask_logits[:, :, :, 0]
gt_crop = filled_boxes[:, :, :, :, 0]
pred_crop = mask_logits[:, :, :, :, 0]
else:
gt_crop = crop_and_resize_instance_masks(
filled_boxes, boxes_for_crop, self._deepmac_params.mask_size)
......@@ -1071,7 +1180,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_logits, boxes_for_crop, self._deepmac_params.mask_size)
loss = 0.0
for axis in [1, 2]:
for axis in [2, 3]:
if self._deepmac_params.box_consistency_tightness:
pred_max_raw = tf.reduce_max(pred_crop, axis=axis)
......@@ -1083,44 +1192,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
else:
pred_max = tf.reduce_max(pred_crop, axis=axis)
pred_max = pred_max[:, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, tf.newaxis]
pred_max = pred_max[:, :, :, tf.newaxis]
gt_max = tf.reduce_max(gt_crop, axis=axis)[:, :, :, tf.newaxis]
flat_pred, batch_size, num_instances = flatten_first2_dims(pred_max)
flat_gt, _, _ = flatten_first2_dims(gt_max)
# We use flat tensors while calling loss functions because we
# want the loss per-instance to later multiply with the per-instance
# weight. Flattening the first 2 dims allows us to represent each instance
# in each batch as though they were samples in a larger batch.
raw_loss = self._deepmac_params.classification_loss(
prediction_tensor=pred_max,
target_tensor=gt_max,
weights=tf.ones_like(pred_max))
prediction_tensor=flat_pred,
target_tensor=flat_gt,
weights=tf.ones_like(flat_pred))
loss += self._aggregate_classification_loss(
raw_loss, gt_max, pred_max,
agg_loss = self._aggregate_classification_loss(
raw_loss, flat_gt, flat_pred,
self._deepmac_params.box_consistency_loss_normalize)
loss += unpack_first2_dims(agg_loss, batch_size, num_instances)
return loss
def _compute_per_instance_color_consistency_loss(
def _compute_color_consistency_loss(
self, boxes, preprocessed_image, mask_logits):
"""Compute the per-instance color consistency loss.
Args:
boxes: A [num_instances, 4] float tensor of GT boxes.
preprocessed_image: A [height, width, 3] float tensor containing the
preprocessed image.
mask_logits: A [num_instances, height, width] float tensor of predicted
masks.
boxes: A [batch_size, num_instances, 4] float tensor of GT boxes.
preprocessed_image: A [batch_size, height, width, 3]
float tensor containing the preprocessed image.
mask_logits: A [batch_size, num_instances, height, width] float tensor of
predicted masks.
Returns:
loss: A [num_instances] shaped tensor with the loss for each instance.
loss: A [batch_size, num_instances] shaped tensor with the loss for each
instance fpr each sample in the batch.
"""
if not self._deepmac_params.predict_full_resolution_masks:
logging.info('Color consistency is not implemented with RoIAlign '
', i.e, fixed sized masks. Returning 0 loss.')
return tf.zeros(tf.shape(boxes)[0])
return tf.zeros(tf.shape(boxes)[:2])
dilation = self._deepmac_params.color_consistency_dilation
height, width = (tf.shape(preprocessed_image)[0],
tf.shape(preprocessed_image)[1])
height, width = (tf.shape(preprocessed_image)[1],
tf.shape(preprocessed_image)[2])
color_similarity = dilated_cross_pixel_similarity(
preprocessed_image, dilation=dilation, theta=2.0)
mask_probs = tf.nn.sigmoid(mask_logits)
......@@ -1132,20 +1250,20 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
color_similarity_mask = (
color_similarity > self._deepmac_params.color_consistency_threshold)
color_similarity_mask = tf.cast(
color_similarity_mask[:, tf.newaxis, :, :], tf.float32)
color_similarity_mask[:, :, tf.newaxis, :, :], tf.float32)
per_pixel_loss = -(color_similarity_mask *
tf.math.log(same_mask_label_probability))
# TODO(vighneshb) explore if shrinking the box by 1px helps.
box_mask = fill_boxes(boxes, height, width)
box_mask_expanded = box_mask[tf.newaxis, :, :, :]
box_mask_expanded = box_mask[tf.newaxis]
per_pixel_loss = per_pixel_loss * box_mask_expanded
loss = tf.reduce_sum(per_pixel_loss, axis=[0, 2, 3])
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[1, 2]))
loss = tf.reduce_sum(per_pixel_loss, axis=[0, 3, 4])
num_box_pixels = tf.maximum(1.0, tf.reduce_sum(box_mask, axis=[2, 3]))
loss = loss / num_box_pixels
if ((self._deepmac_params.color_consistency_warmup_steps > 0) and
self._is_training):
tf.keras.backend.learning_phase()):
training_step = tf.cast(self.training_step, tf.float32)
warmup_steps = tf.cast(
self._deepmac_params.color_consistency_warmup_steps, tf.float32)
......@@ -1157,56 +1275,53 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
return loss
def _compute_per_instance_deepmac_losses(
self, boxes, masks, instance_embedding, pixel_embedding,
image):
def _compute_deepmac_losses(
self, boxes, masks_logits, masks_gt, image):
"""Returns the mask loss per instance.
Args:
boxes: A [num_instances, 4] float tensor holding bounding boxes. The
coordinates are in normalized input space.
masks: A [num_instances, input_height, input_width] float tensor
containing the instance masks.
instance_embedding: A [output_height, output_width, embedding_size]
float tensor containing the instance embeddings.
pixel_embedding: optional [output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embeddings.
image: [output_height, output_width, channels] float tensor
boxes: A [batch_size, num_instances, 4] float tensor holding bounding
boxes. The coordinates are in normalized input space.
masks_logits: A [batch_size, num_instances, input_height, input_width]
float tensor containing the instance mask predictions in their logit
form.
masks_gt: A [batch_size, num_instances, input_height, input_width] float
tensor containing the groundtruth masks.
image: [batch_size, output_height, output_width, channels] float tensor
denoting the input image.
Returns:
mask_prediction_loss: A [num_instances] shaped float tensor containing the
mask loss for each instance.
box_consistency_loss: A [num_instances] shaped float tensor containing
the box consistency loss for each instance.
box_consistency_loss: A [num_instances] shaped float tensor containing
the color consistency loss.
mask_prediction_loss: A [batch_size, num_instances] shaped float tensor
containing the mask loss for each instance in the batch.
box_consistency_loss: A [batch_size, num_instances] shaped float tensor
containing the box consistency loss for each instance in the batch.
box_consistency_loss: A [batch_size, num_instances] shaped float tensor
containing the color consistency loss in the batch.
"""
if tf.keras.backend.learning_phase():
boxes_for_crop = preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
boxes = tf.stop_gradient(boxes)
def jitter_func(boxes):
return preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
boxes_for_crop = tf.map_fn(jitter_func,
boxes, parallel_iterations=128)
else:
boxes_for_crop = boxes
mask_input = self._get_mask_head_input(
boxes_for_crop, pixel_embedding)
instance_embeddings = self._get_instance_embeddings(
boxes_for_crop, instance_embedding)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_gt = self._get_groundtruth_mask_output(boxes_for_crop, masks)
mask_gt = self._get_groundtruth_mask_output(
boxes_for_crop, masks_gt)
mask_prediction_loss = self._compute_per_instance_mask_prediction_loss(
boxes_for_crop, mask_logits, mask_gt)
mask_prediction_loss = self._compute_mask_prediction_loss(
boxes_for_crop, masks_logits, mask_gt)
box_consistency_loss = self._compute_per_instance_box_consistency_loss(
boxes, boxes_for_crop, mask_logits)
box_consistency_loss = self._compute_box_consistency_loss(
boxes, boxes_for_crop, masks_logits)
color_consistency_loss = self._compute_per_instance_color_consistency_loss(
boxes, image, mask_logits)
color_consistency_loss = self._compute_color_consistency_loss(
boxes, image, masks_logits)
return {
DEEP_MASK_ESTIMATION: mask_prediction_loss,
......@@ -1224,7 +1339,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
' consistency loss is not supported in TF1.'))
return tfio.experimental.color.rgb_to_lab(raw_image)
def _compute_instance_masks_loss(self, prediction_dict):
def _compute_masks_loss(self, prediction_dict):
"""Computes the mask loss.
Args:
......@@ -1236,10 +1351,6 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
Returns:
loss_dict: A dict mapping string (loss names) to scalar floats.
"""
gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
gt_masks_list = self.groundtruth_lists(fields.BoxListFields.masks)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
allowed_masked_classes_ids = (
self._deepmac_params.allowed_masked_classes_ids)
......@@ -1248,8 +1359,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
for loss_name in MASK_LOSSES:
loss_dict[loss_name] = 0.0
prediction_shape = tf.shape(prediction_dict[INSTANCE_EMBEDDING][0])
height, width = prediction_shape[1], prediction_shape[2]
prediction_shape = tf.shape(prediction_dict[MASK_LOGITS_GT_BOXES][0])
height, width = prediction_shape[2], prediction_shape[3]
preprocessed_image = tf.image.resize(
prediction_dict['preprocessed_inputs'], (height, width))
......@@ -1258,42 +1369,46 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
# TODO(vighneshb) See if we can save memory by only using the final
# prediction
# Iterate over multiple predictions by backbone (for hourglass length=2)
for instance_pred, pixel_pred in zip(
prediction_dict[INSTANCE_EMBEDDING],
prediction_dict[PIXEL_EMBEDDING]):
# Iterate over samples in batch
# TODO(vighneshb) find out how autograph is handling this. Converting
# to a single op may give speed improvements
for i, (boxes, weights, classes, masks) in enumerate(
zip(gt_boxes_list, gt_weights_list, gt_classes_list, gt_masks_list)):
# TODO(vighneshb) Add sub-sampling back if required.
classes, valid_mask_weights, masks = filter_masked_classes(
allowed_masked_classes_ids, classes, weights, masks)
sample_loss_dict = self._compute_per_instance_deepmac_losses(
boxes, masks, instance_pred[i], pixel_pred[i], image[i])
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= weights
num_instances = tf.maximum(tf.reduce_sum(weights), 1.0)
num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) /
num_instances_allowed)
for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) /
num_instances)
batch_size = len(gt_boxes_list)
num_predictions = len(prediction_dict[INSTANCE_EMBEDDING])
return dict((key, loss / float(batch_size * num_predictions))
gt_boxes = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.boxes))
gt_weights = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.weights))
gt_masks = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.masks))
gt_classes = _batch_gt_list(
self.groundtruth_lists(fields.BoxListFields.classes))
mask_logits_list = prediction_dict[MASK_LOGITS_GT_BOXES]
for mask_logits in mask_logits_list:
# TODO(vighneshb) Add sub-sampling back if required.
_, valid_mask_weights, gt_masks = filter_masked_classes(
allowed_masked_classes_ids, gt_classes,
gt_weights, gt_masks)
sample_loss_dict = self._compute_deepmac_losses(
gt_boxes, mask_logits, gt_masks, image)
sample_loss_dict[DEEP_MASK_ESTIMATION] *= valid_mask_weights
for loss_name in WEAK_LOSSES:
sample_loss_dict[loss_name] *= gt_weights
num_instances = tf.maximum(tf.reduce_sum(gt_weights), 1.0)
num_instances_allowed = tf.maximum(
tf.reduce_sum(valid_mask_weights), 1.0)
loss_dict[DEEP_MASK_ESTIMATION] += (
tf.reduce_sum(sample_loss_dict[DEEP_MASK_ESTIMATION]) /
num_instances_allowed)
for loss_name in WEAK_LOSSES:
loss_dict[loss_name] += (tf.reduce_sum(sample_loss_dict[loss_name]) /
num_instances)
num_predictions = len(mask_logits_list)
return dict((key, loss / float(num_predictions))
for key, loss in loss_dict.items())
def loss(self, prediction_dict, true_image_shapes, scope=None):
......@@ -1302,7 +1417,7 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
prediction_dict, true_image_shapes, scope)
if self._deepmac_params is not None:
mask_loss_dict = self._compute_instance_masks_loss(
mask_loss_dict = self._compute_masks_loss(
prediction_dict=prediction_dict)
for loss_name in MASK_LOSSES:
......@@ -1363,50 +1478,18 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
mask_size] containing binary per-box instance masks.
"""
def process(elems):
boxes, instance_embedding, pixel_embedding = elems
return self._postprocess_sample(boxes, instance_embedding,
pixel_embedding)
max_instances = self._center_params.max_box_predictions
return tf.map_fn(process, [boxes_output_stride, instance_embedding,
pixel_embedding],
dtype=tf.float32, parallel_iterations=max_instances)
def _postprocess_sample(self, boxes_output_stride,
instance_embedding, pixel_embedding):
"""Post process masks for a single sample.
Args:
boxes_output_stride: A [num_instances, 4] float tensor containing
bounding boxes in the absolute output space.
instance_embedding: A [output_height, output_width, embedding_size]
float tensor containing instance embeddings.
pixel_embedding: A [batch_size, output_height, output_width,
pixel_embedding_size] float tensor containing the per-pixel embedding.
Returns:
masks: A float tensor of size [num_instances, mask_height, mask_width]
containing binary per-box instance masks. If
predict_full_resolution_masks is set, the masks will be resized to
postprocess_crop_size. Otherwise, mask_height=mask_width=mask_size
"""
height, width = (tf.shape(instance_embedding)[0],
tf.shape(instance_embedding)[1])
height, width = (tf.shape(instance_embedding)[1],
tf.shape(instance_embedding)[2])
height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
blist = box_list.BoxList(boxes_output_stride)
blist = box_list_ops.to_normalized_coordinates(
blist, height, width, check_range=False)
boxes = blist.get()
mask_input = self._get_mask_head_input(boxes, pixel_embedding)
instance_embeddings = self._get_instance_embeddings(
boxes, instance_embedding)
ymin, xmin, ymax, xmax = tf.unstack(boxes_output_stride, axis=2)
ymin /= height
ymax /= height
xmin /= width
xmax /= width
boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
mask_logits = self._mask_net(
instance_embeddings, mask_input,
training=tf.keras.backend.learning_phase())
mask_logits = self._predict_mask_logits_from_embeddings(
pixel_embedding, instance_embedding, boxes)
# TODO(vighneshb) Explore sweeping mask thresholds.
......@@ -1416,7 +1499,8 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
height *= self._stride
width *= self._stride
mask_logits = resize_instance_masks(mask_logits, (height, width))
mask_logits = crop_masks_within_boxes(
mask_logits = crop_and_resize_instance_masks(
mask_logits, boxes, self._deepmac_params.postprocess_crop_size)
masks_prob = tf.nn.sigmoid(mask_logits)
......
......@@ -44,6 +44,7 @@ DEEPMAC_PROTO_TEXT = """
box_consistency_loss_normalize: NORMALIZE_AUTO
color_consistency_warmup_steps: 20
color_consistency_warmup_start: 10
use_only_last_stage: false
"""
......@@ -117,10 +118,11 @@ def build_meta_arch(**override_params):
mask_size=16,
postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random',
roi_jitter_mode='default',
color_consistency_dilation=2,
color_consistency_warmup_steps=0,
color_consistency_warmup_start=0)
color_consistency_warmup_start=0,
use_only_last_stage=True)
params.update(override_params)
......@@ -185,6 +187,7 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(params, deepmac_meta_arch.DeepMACParams)
self.assertEqual(params.dim, 153)
self.assertEqual(params.box_consistency_loss_normalize, 'normalize_auto')
self.assertFalse(params.use_only_last_stage)
def test_subsample_trivial(self):
"""Test subsampling masks."""
......@@ -201,32 +204,71 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(result[2], boxes)
self.assertAllClose(result[3], masks)
def test_filter_masked_classes(self):
  """Checks filter_masked_classes keeps only instances of allowed class ids.

  Class ids are 1-indexed (one-hot index + 1); only ids 3 and 4 are
  allowed, so every other instance's class, weight and mask are zeroed.
  """
  # Batch of 2 samples x 3 instances with class ids [1, 2, 3] and [4, 5, 5].
  class_indices = tf.constant([[0, 1, 2], [3, 4, 4]])
  classes = tf.one_hot(class_indices, depth=5)
  weights = tf.constant([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
  masks = tf.ones((2, 3, 32, 32), dtype=tf.float32)

  classes, weights, masks = deepmac_meta_arch.filter_masked_classes(
      [3, 4], classes, weights, masks)

  # Only instance (0, 2) [class id 3] and (1, 0) [class id 4] survive.
  expected_classes = np.zeros((2, 3, 5), dtype=np.float32)
  expected_classes[0, 2, 2] = 1.0
  expected_classes[1, 0, 3] = 1.0
  self.assertAllClose(expected_classes, classes.numpy())
  self.assertAllClose(np.array(([0.0, 0.0, 1.0], [1.0, 0.0, 0.0])), weights)
  self.assertAllClose(masks[0, 0], np.zeros((32, 32)))
  self.assertAllClose(masks[0, 1], np.zeros((32, 32)))
  self.assertAllClose(masks[0, 2], np.ones((32, 32)))
  self.assertAllClose(masks[1, 0], np.ones((32, 32)))
  self.assertAllClose(masks[1, 1], np.zeros((32, 32)))
def test_fill_boxes(self):
boxes = tf.constant([[0., 0., 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]])
boxes = tf.constant([[[0., 0., 0.5, 0.5], [0.5, 0.5, 1.0, 1.0]],
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
filled_boxes = deepmac_meta_arch.fill_boxes(boxes, 32, 32)
expected = np.zeros((2, 32, 32))
expected[0, :17, :17] = 1.0
expected[1, 16:, 16:] = 1.0
expected = np.zeros((2, 2, 32, 32))
expected[0, 0, :17, :17] = 1.0
expected[0, 1, 16:, 16:] = 1.0
expected[1, 0, :, :] = 1.0
filled_boxes = filled_boxes.numpy()
self.assertAllClose(expected[0, 0], filled_boxes[0, 0], rtol=1e-3)
self.assertAllClose(expected[0, 1], filled_boxes[0, 1], rtol=1e-3)
self.assertAllClose(expected[1, 0], filled_boxes[1, 0], rtol=1e-3)
self.assertAllClose(expected, filled_boxes.numpy(), rtol=1e-3)
def test_flatten_and_unpack(self):
  """Checks that unpack_first2_dims is the inverse of flatten_first2_dims."""
  original = tf.random.uniform((2, 3, 4, 5, 6))
  # Wrap in tf.function to exercise the graph-mode (autograph) path too.
  flatten_fn = tf.function(deepmac_meta_arch.flatten_first2_dims)
  unpack_fn = tf.function(deepmac_meta_arch.unpack_first2_dims)
  flattened, dim0, dim1 = flatten_fn(original)
  restored = unpack_fn(flattened, dim0, dim1)
  self.assertAllClose(restored.numpy(), original)
def test_crop_and_resize_instance_masks(self):
boxes = tf.zeros((5, 4))
masks = tf.zeros((5, 128, 128))
boxes = tf.zeros((8, 5, 4))
masks = tf.zeros((8, 5, 128, 128))
output = deepmac_meta_arch.crop_and_resize_instance_masks(
masks, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32))
def test_crop_and_resize_feature_map(self):
boxes = tf.zeros((5, 4))
features = tf.zeros((128, 128, 7))
output = deepmac_meta_arch.crop_and_resize_feature_map(
features, boxes, 32)
self.assertEqual(output.shape, (5, 32, 32, 7))
self.assertEqual(output.shape, (8, 5, 32, 32))
def test_embedding_projection_prob_shape(self):
dist = deepmac_meta_arch.embedding_projection(
......@@ -262,73 +304,75 @@ class DeepMACUtilsTest(tf.test.TestCase, parameterized.TestCase):
def test_generate_2d_neighbors_shape(self):
inp = tf.zeros((13, 14, 3))
inp = tf.zeros((5, 13, 14, 3))
out = deepmac_meta_arch.generate_2d_neighbors(inp)
self.assertEqual((8, 13, 14, 3), out.shape)
self.assertEqual((8, 5, 13, 14, 3), out.shape)
def test_generate_2d_neighbors(self):
inp = np.arange(16).reshape(4, 4).astype(np.float32)
inp = tf.stack([inp, inp * 2], axis=2)
inp = tf.reshape(inp, (1, 4, 4, 2))
out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=1)
self.assertEqual((8, 4, 4, 2), out.shape)
self.assertEqual((8, 1, 4, 4, 2), out.shape)
for i in range(2):
expected = np.array([0, 1, 2, 4, 6, 8, 9, 10]) * (i + 1)
self.assertAllEqual(out[:, 1, 1, i], expected)
self.assertAllEqual(out[:, 0, 1, 1, i], expected)
expected = np.array([1, 2, 3, 5, 7, 9, 10, 11]) * (i + 1)
self.assertAllEqual(out[:, 1, 2, i], expected)
self.assertAllEqual(out[:, 0, 1, 2, i], expected)
expected = np.array([4, 5, 6, 8, 10, 12, 13, 14]) * (i + 1)
self.assertAllEqual(out[:, 2, 1, i], expected)
self.assertAllEqual(out[:, 0, 2, 1, i], expected)
expected = np.array([5, 6, 7, 9, 11, 13, 14, 15]) * (i + 1)
self.assertAllEqual(out[:, 2, 2, i], expected)
self.assertAllEqual(out[:, 0, 2, 2, i], expected)
def test_generate_2d_neighbors_dilation2(self):
inp = np.arange(16).reshape(4, 4, 1).astype(np.float32)
inp = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
out = deepmac_meta_arch.generate_2d_neighbors(inp, dilation=2)
self.assertEqual((8, 4, 4, 1), out.shape)
self.assertEqual((8, 1, 4, 4, 1), out.shape)
expected = np.array([0, 0, 0, 0, 2, 0, 8, 10])
self.assertAllEqual(out[:, 0, 0, 0], expected)
self.assertAllEqual(out[:, 0, 0, 0, 0], expected)
def test_dilated_similarity_shape(self):
fmap = tf.zeros((32, 32, 9))
fmap = tf.zeros((5, 32, 32, 9))
similarity = deepmac_meta_arch.dilated_cross_pixel_similarity(
fmap)
self.assertEqual((8, 32, 32), similarity.shape)
self.assertEqual((8, 5, 32, 32), similarity.shape)
def test_dilated_similarity(self):
fmap = np.zeros((5, 5, 2), dtype=np.float32)
fmap = np.zeros((1, 5, 5, 2), dtype=np.float32)
fmap[0, 0, :] = 1.0
fmap[4, 4, :] = 1.0
fmap[0, 0, 0, :] = 1.0
fmap[0, 4, 4, :] = 1.0
similarity = deepmac_meta_arch.dilated_cross_pixel_similarity(
fmap, theta=1.0, dilation=2)
self.assertAlmostEqual(similarity.numpy()[0, 2, 2],
self.assertAlmostEqual(similarity.numpy()[0, 0, 2, 2],
np.exp(-np.sqrt(2)))
def test_dilated_same_instance_mask_shape(self):
instances = tf.zeros((5, 32, 32))
instances = tf.zeros((2, 5, 32, 32))
output = deepmac_meta_arch.dilated_cross_same_mask_label(instances)
self.assertEqual((8, 5, 32, 32), output.shape)
self.assertEqual((8, 2, 5, 32, 32), output.shape)
def test_dilated_same_instance_mask(self):
instances = np.zeros((3, 2, 5, 5), dtype=np.float32)
instances[0, 0, 0, 0] = 1.0
instances[0, 0, 2, 2] = 1.0
instances[0, 0, 4, 4] = 1.0
instances[2, 0, 0, 0] = 1.0
instances[2, 0, 2, 2] = 1.0
instances[2, 0, 4, 4] = 0.0
instances = np.zeros((2, 5, 5), dtype=np.float32)
instances[0, 0, 0] = 1.0
instances[0, 2, 2] = 1.0
instances[0, 4, 4] = 1.0
output = deepmac_meta_arch.dilated_cross_same_mask_label(instances).numpy()
self.assertAllClose(np.ones((8, 5, 5)), output[:, 1, :, :])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 2, 2])
self.assertAllClose(np.ones((8, 2, 5, 5)), output[:, 1, :, :])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 1], output[:, 0, 0, 2, 2])
self.assertAllClose([1, 0, 0, 0, 0, 0, 0, 0], output[:, 2, 0, 2, 2])
def test_per_pixel_single_conv_multiple_instance(self):
......@@ -550,151 +594,184 @@ class DeepMACMaskHeadTest(tf.test.TestCase, parameterized.TestCase):
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
# TODO(vighneshb): Add batch_size > 1 tests for loss functions.
def setUp(self): # pylint:disable=g-missing-super-call
self.model = build_meta_arch()
def test_get_mask_head_input(self):
boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]],
boxes = tf.constant([[[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]],
[[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]]],
dtype=tf.float32)
pixel_embedding = np.zeros((32, 32, 4), dtype=np.float32)
pixel_embedding[:16, :16] = 1.0
pixel_embedding[16:, 16:] = 2.0
pixel_embedding = np.zeros((2, 32, 32, 4), dtype=np.float32)
pixel_embedding[0, :16, :16] = 1.0
pixel_embedding[0, 16:, 16:] = 2.0
pixel_embedding[1, :16, :16] = 3.0
pixel_embedding[1, 16:, 16:] = 4.0
pixel_embedding = tf.constant(pixel_embedding)
mask_inputs = self.model._get_mask_head_input(boxes, pixel_embedding)
self.assertEqual(mask_inputs.shape, (2, 16, 16, 6))
self.assertEqual(mask_inputs.shape, (2, 2, 16, 16, 6))
y_grid, x_grid = tf.meshgrid(np.linspace(-1.0, 1.0, 16),
np.linspace(-1.0, 1.0, 16), indexing='ij')
for i in range(2):
mask_input = mask_inputs[i]
self.assertAllClose(y_grid, mask_input[:, :, 0])
self.assertAllClose(x_grid, mask_input[:, :, 1])
pixel_embedding = mask_input[:, :, 2:]
self.assertAllClose(np.zeros((16, 16, 4)) + i + 1, pixel_embedding)
for i, j in ([0, 0], [0, 1], [1, 0], [1, 1]):
self.assertAllClose(y_grid, mask_inputs[i, j, :, :, 0])
self.assertAllClose(x_grid, mask_inputs[i, j, :, :, 1])
zeros = np.zeros((16, 16, 4))
self.assertAllClose(zeros + 1, mask_inputs[0, 0, :, :, 2:])
self.assertAllClose(zeros + 2, mask_inputs[0, 1, :, :, 2:])
self.assertAllClose(zeros + 3, mask_inputs[1, 0, :, :, 2:])
self.assertAllClose(zeros + 4, mask_inputs[1, 1, :, :, 2:])
def test_get_mask_head_input_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True)
boxes = tf.constant([[0., 0., 1.0, 1.0], [0.0, 0.0, 0.5, 1.0]],
dtype=tf.float32)
boxes = tf.constant([[[0., 0., 1.0, 1.0], [0.0, 0.0, 0.5, 1.0]],
[[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]])
pixel_embedding_np = np.random.randn(32, 32, 4).astype(np.float32)
pixel_embedding_np = np.random.randn(2, 32, 32, 4).astype(np.float32)
pixel_embedding = tf.constant(pixel_embedding_np)
mask_inputs = model._get_mask_head_input(boxes, pixel_embedding)
self.assertEqual(mask_inputs.shape, (2, 32, 32, 6))
self.assertEqual(mask_inputs.shape, (2, 2, 32, 32, 6))
y_grid, x_grid = tf.meshgrid(np.linspace(.0, 1.0, 32),
np.linspace(.0, 1.0, 32), indexing='ij')
ys = [0.5, 0.25]
xs = [0.5, 0.5]
for i in range(2):
mask_input = mask_inputs[i]
self.assertAllClose(y_grid - ys[i], mask_input[:, :, 0])
self.assertAllClose(x_grid - xs[i], mask_input[:, :, 1])
pixel_embedding = mask_input[:, :, 2:]
self.assertAllClose(pixel_embedding_np, pixel_embedding)
self.assertAllClose(y_grid - 0.5, mask_inputs[0, 0, :, :, 0])
self.assertAllClose(x_grid - 0.5, mask_inputs[0, 0, :, :, 1])
self.assertAllClose(y_grid - 0.25, mask_inputs[0, 1, :, :, 0])
self.assertAllClose(x_grid - 0.5, mask_inputs[0, 1, :, :, 1])
self.assertAllClose(y_grid - 0.75, mask_inputs[1, 0, :, :, 0])
self.assertAllClose(x_grid - 0.75, mask_inputs[1, 0, :, :, 1])
self.assertAllClose(y_grid, mask_inputs[1, 1, :, :, 0])
self.assertAllClose(x_grid, mask_inputs[1, 1, :, :, 1])
def test_get_instance_embeddings(self):
embeddings = np.zeros((32, 32, 2))
embeddings[8, 8] = 1.0
embeddings[24, 16] = 2.0
embeddings = np.zeros((2, 32, 32, 2))
embeddings[0, 8, 8] = 1.0
embeddings[0, 24, 16] = 2.0
embeddings[1, 8, 16] = 3.0
embeddings = tf.constant(embeddings)
boxes = tf.constant([[0., 0., 0.5, 0.5], [0.5, 0.0, 1.0, 1.0]])
boxes = np.zeros((2, 2, 4), dtype=np.float32)
boxes[0, 0] = [0.0, 0.0, 0.5, 0.5]
boxes[0, 1] = [0.5, 0.0, 1.0, 1.0]
boxes[1, 0] = [0.0, 0.0, 0.5, 1.0]
boxes = tf.constant(boxes)
center_embeddings = self.model._get_instance_embeddings(boxes, embeddings)
self.assertAllClose(center_embeddings, [[1.0, 1.0], [2.0, 2.0]])
self.assertAllClose(center_embeddings[0, 0], [1.0, 1.0])
self.assertAllClose(center_embeddings[0, 1], [2.0, 2.0])
self.assertAllClose(center_embeddings[1, 0], [3.0, 3.0])
def test_get_groundtruth_mask_output(self):
boxes = tf.constant([[0., 0., 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]],
dtype=tf.float32)
masks = np.zeros((2, 32, 32), dtype=np.float32)
masks[0, :16, :16] = 0.5
masks[1, 16:, 16:] = 0.1
boxes = np.zeros((2, 2, 4))
masks = np.zeros((2, 2, 32, 32))
boxes[0, 0] = [0.0, 0.0, 0.25, 0.25]
boxes[0, 1] = [0.75, 0.75, 1.0, 1.0]
boxes[1, 0] = [0.0, 0.0, 0.5, 1.0]
masks = np.zeros((2, 2, 32, 32), dtype=np.float32)
masks[0, 0, :16, :16] = 0.5
masks[0, 1, 16:, 16:] = 0.1
masks[1, 0, :17, :] = 0.3
masks = self.model._get_groundtruth_mask_output(boxes, masks)
self.assertEqual(masks.shape, (2, 16, 16))
self.assertEqual(masks.shape, (2, 2, 16, 16))
self.assertAllClose(masks[0], np.zeros((16, 16)) + 0.5)
self.assertAllClose(masks[1], np.zeros((16, 16)) + 0.1)
self.assertAllClose(masks[0, 0], np.zeros((16, 16)) + 0.5)
self.assertAllClose(masks[0, 1], np.zeros((16, 16)) + 0.1)
self.assertAllClose(masks[1, 0], np.zeros((16, 16)) + 0.3)
def test_get_groundtruth_mask_output_crop_resize(self):
def test_get_groundtruth_mask_output_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True)
boxes = tf.constant([[0., 0., 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
dtype=tf.float32)
masks = tf.ones((2, 32, 32))
boxes = tf.zeros((2, 5, 4))
masks = tf.ones((2, 5, 32, 32))
masks = model._get_groundtruth_mask_output(boxes, masks)
self.assertAllClose(masks, np.ones((2, 32, 32)))
self.assertAllClose(masks, np.ones((2, 5, 32, 32)))
def test_per_instance_loss(self):
def test_predict(self):
model = build_meta_arch()
model._mask_net = MockMaskNet()
boxes = tf.constant([[0.0, 0.0, 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]])
masks = np.zeros((2, 32, 32), dtype=np.float32)
masks[0, :16, :16] = 1.0
masks[1, 16:, 16:] = 1.0
masks = tf.constant(masks)
tf.keras.backend.set_learning_phase(True)
self.model.provide_groundtruth(
groundtruth_boxes_list=[tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)],
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)],
groundtruth_weights_list=[tf.ones(5)],
groundtruth_masks_list=[tf.ones((5, 32, 32))])
prediction = self.model.predict(tf.zeros((1, 32, 32, 3)), None)
self.assertEqual(prediction['MASK_LOGITS_GT_BOXES'][0].shape,
(1, 5, 16, 16))
def test_loss(self):
loss_dict = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((16, 16, 3)))
model = build_meta_arch()
boxes = tf.constant([[[0.0, 0.0, 0.25, 0.25], [0.75, 0.75, 1.0, 1.0]]])
masks = np.zeros((1, 2, 32, 32), dtype=np.float32)
masks[0, 0, :16, :16] = 1.0
masks[0, 1, 16:, 16:] = 1.0
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
loss_dict = model._compute_deepmac_losses(
boxes, masks_pred, masks, tf.zeros((1, 16, 16, 3)))
self.assertAllClose(
loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
np.zeros((1, 2)) - tf.math.log(tf.nn.sigmoid(0.9)))
def test_per_instance_loss_no_crop_resize(self):
def test_loss_no_crop_resize(self):
model = build_meta_arch(predict_full_resolution_masks=True)
model._mask_net = MockMaskNet()
boxes = tf.constant([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]])
masks = np.ones((2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
masks = tf.ones((1, 2, 128, 128), dtype=tf.float32)
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
loss_dict = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((32, 32, 3)))
loss_dict = model._compute_deepmac_losses(
boxes, masks_pred, masks, tf.zeros((1, 32, 32, 3)))
self.assertAllClose(
loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
np.zeros(2) - tf.math.log(tf.nn.sigmoid(0.9)))
np.zeros((1, 2)) - tf.math.log(tf.nn.sigmoid(0.9)))
def test_per_instance_loss_no_crop_resize_dice(self):
def test_loss_no_crop_resize_dice(self):
model = build_meta_arch(predict_full_resolution_masks=True,
use_dice_loss=True)
model._mask_net = MockMaskNet()
boxes = tf.constant([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]])
masks = np.ones((2, 128, 128), dtype=np.float32)
boxes = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
masks = np.ones((1, 2, 128, 128), dtype=np.float32)
masks = tf.constant(masks)
masks_pred = tf.fill((1, 2, 32, 32), 0.9)
loss_dict = model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((32, 32, 3)))
loss_dict = model._compute_deepmac_losses(
boxes, masks_pred, masks, tf.zeros((1, 32, 32, 3)))
pred = tf.nn.sigmoid(0.9)
expected = (1.0 - ((2.0 * pred) / (1.0 + pred)))
self.assertAllClose(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION],
[expected, expected], rtol=1e-3)
[[expected, expected]], rtol=1e-3)
def test_empty_masks(self):
boxes = tf.zeros([0, 4])
masks = tf.zeros([0, 128, 128])
loss_dict = self.model._compute_per_instance_deepmac_losses(
boxes, masks, tf.zeros((32, 32, 2)), tf.zeros((32, 32, 2)),
tf.zeros((16, 16, 3)))
boxes = tf.zeros([1, 0, 4])
masks = tf.zeros([1, 0, 128, 128])
loss_dict = self.model._compute_deepmac_losses(
boxes, masks, masks,
tf.zeros((1, 16, 16, 3)))
self.assertEqual(loss_dict[deepmac_meta_arch.DEEP_MASK_ESTIMATION].shape,
(0,))
(1, 0))
def test_postprocess(self):
model = build_meta_arch()
model._mask_net = MockMaskNet()
boxes = np.zeros((2, 3, 4), dtype=np.float32)
......@@ -708,7 +785,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(masks, prob * np.ones((2, 3, 16, 16)))
def test_postprocess_emb_proj(self):
model = build_meta_arch(network_type='embedding_projection',
use_instance_embedding=False,
use_xy=False, pixel_embedding_dim=8,
......@@ -724,7 +800,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual(masks.shape, (2, 3, 16, 16))
def test_postprocess_emb_proj_fullres(self):
model = build_meta_arch(network_type='embedding_projection',
predict_full_resolution_masks=True,
use_instance_embedding=False,
......@@ -751,17 +826,6 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
prob = tf.nn.sigmoid(0.9).numpy()
self.assertAllClose(masks, prob * np.ones((2, 3, 128, 128)))
def test_crop_masks_within_boxes(self):
masks = np.zeros((2, 32, 32))
masks[0, :16, :16] = 1.0
masks[1, 16:, 16:] = 1.0
boxes = tf.constant([[0.0, 0.0, 15.0 / 32, 15.0 / 32],
[0.5, 0.5, 1.0, 1]])
masks = deepmac_meta_arch.crop_masks_within_boxes(
masks, boxes, 128)
masks = (masks.numpy() > 0.0).astype(np.float32)
self.assertAlmostEqual(masks.sum(), 2 * 128 * 128)
def test_transform_boxes_to_feature_coordinates(self):
batch_size = 2
model = build_meta_arch()
......@@ -816,13 +880,13 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def test_box_consistency_loss(self):
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]])
boxes_gt = tf.constant([[[0., 0., 0.49, 1.0]]])
boxes_jittered = tf.constant([[[0.0, 0.0, 1.0, 1.0]]])
mask_prediction = np.zeros((1, 32, 32)).astype(np.float32)
mask_prediction[0, :24, :24] = 1.0
mask_prediction = np.zeros((1, 1, 32, 32)).astype(np.float32)
mask_prediction[0, 0, :24, :24] = 1.0
loss = self.model._compute_per_instance_box_consistency_loss(
loss = self.model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = tf.nn.sigmoid_cross_entropy_with_logits(
......@@ -834,39 +898,39 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
yloss_mean = tf.reduce_mean(yloss)
xloss_mean = tf.reduce_mean(xloss)
self.assertAllClose(loss, [yloss_mean + xloss_mean])
self.assertAllClose(loss[0], [yloss_mean + xloss_mean])
def test_box_consistency_loss_with_tightness(self):
boxes_gt = tf.constant([[0., 0., 0.49, 0.49]])
boxes_gt = tf.constant([[[0., 0., 0.49, 0.49]]])
boxes_jittered = None
mask_prediction = np.zeros((1, 8, 8)).astype(np.float32) - 1e10
mask_prediction[0, :4, :4] = 1e10
mask_prediction = np.zeros((1, 1, 8, 8)).astype(np.float32) - 1e10
mask_prediction[0, 0, :4, :4] = 1e10
model = build_meta_arch(box_consistency_tightness=True,
predict_full_resolution_masks=True)
loss = model._compute_per_instance_box_consistency_loss(
loss = model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAllClose(loss, [0.0])
self.assertAllClose(loss[0], [0.0])
def test_box_consistency_loss_gt_count(self):
boxes_gt = tf.constant([
boxes_gt = tf.constant([[
[0., 0., 1.0, 1.0],
[0., 0., 0.49, 0.49]])
[0., 0., 0.49, 0.49]]])
boxes_jittered = None
mask_prediction = np.zeros((2, 32, 32)).astype(np.float32)
mask_prediction[0, :16, :16] = 1.0
mask_prediction[1, :8, :8] = 1.0
mask_prediction = np.zeros((1, 2, 32, 32)).astype(np.float32)
mask_prediction[0, 0, :16, :16] = 1.0
mask_prediction[0, 1, :8, :8] = 1.0
model = build_meta_arch(
box_consistency_loss_normalize='normalize_groundtruth_count',
predict_full_resolution_masks=True)
loss_func = tf.function(
model._compute_per_instance_box_consistency_loss)
loss_func = (
model._compute_box_consistency_loss)
loss = loss_func(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
......@@ -877,7 +941,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
xloss = yloss
xloss_mean = tf.reduce_sum(xloss)
self.assertAllClose(loss[0], yloss_mean + xloss_mean)
self.assertAllClose(loss[0, 0], yloss_mean + xloss_mean)
yloss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.constant([1.0] * 16 + [0.0] * 16),
......@@ -885,21 +949,20 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
yloss_mean = tf.reduce_sum(yloss)
xloss = yloss
xloss_mean = tf.reduce_sum(xloss)
self.assertAllClose(loss[1], yloss_mean + xloss_mean)
self.assertAllClose(loss[0, 1], yloss_mean + xloss_mean)
def test_box_consistency_loss_balanced(self):
boxes_gt = tf.constant([
[0., 0., 0.49, 0.49]])
boxes_gt = tf.constant([[
[0., 0., 0.49, 0.49]]])
boxes_jittered = None
mask_prediction = np.zeros((1, 32, 32)).astype(np.float32)
mask_prediction[0] = 1.0
mask_prediction = np.zeros((1, 1, 32, 32)).astype(np.float32)
mask_prediction[0, 0] = 1.0
model = build_meta_arch(box_consistency_loss_normalize='normalize_balanced',
predict_full_resolution_masks=True)
loss_func = tf.function(
model._compute_per_instance_box_consistency_loss)
model._compute_box_consistency_loss)
loss = loss_func(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
......@@ -909,63 +972,64 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
yloss_mean = tf.reduce_sum(yloss) / 16.0
xloss_mean = yloss_mean
self.assertAllClose(loss[0], yloss_mean + xloss_mean)
self.assertAllClose(loss[0, 0], yloss_mean + xloss_mean)
def test_box_consistency_dice_loss(self):
model = build_meta_arch(use_dice_loss=True)
boxes_gt = tf.constant([[0., 0., 0.49, 1.0]])
boxes_jittered = tf.constant([[0.0, 0.0, 1.0, 1.0]])
boxes_gt = tf.constant([[[0., 0., 0.49, 1.0]]])
boxes_jittered = tf.constant([[[0.0, 0.0, 1.0, 1.0]]])
almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, :24, :24] = almost_inf
mask_prediction = np.full((1, 1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, 0, :24, :24] = almost_inf
loss = model._compute_per_instance_box_consistency_loss(
loss = model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
yloss = 1 - 6.0 / 7
xloss = 0.2
self.assertAllClose(loss, [yloss + xloss])
self.assertAllClose(loss, [[yloss + xloss]])
def test_color_consistency_loss_full_res_shape(self):
model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True)
boxes = tf.zeros((3, 4))
img = tf.zeros((32, 32, 3))
mask_logits = tf.zeros((3, 32, 32))
boxes = tf.zeros((5, 3, 4))
img = tf.zeros((5, 32, 32, 3))
mask_logits = tf.zeros((5, 3, 32, 32))
loss = model._compute_per_instance_color_consistency_loss(
loss = model._compute_color_consistency_loss(
boxes, img, mask_logits)
self.assertEqual([3], loss.shape)
self.assertEqual([5, 3], loss.shape)
def test_color_consistency_1_threshold(self):
model = build_meta_arch(predict_full_resolution_masks=True,
color_consistency_threshold=0.99)
boxes = tf.zeros((3, 4))
img = tf.zeros((32, 32, 3))
mask_logits = tf.zeros((3, 32, 32)) - 1e4
boxes = tf.zeros((5, 3, 4))
img = tf.zeros((5, 32, 32, 3))
mask_logits = tf.zeros((5, 3, 32, 32)) - 1e4
loss = model._compute_per_instance_color_consistency_loss(
loss = model._compute_color_consistency_loss(
boxes, img, mask_logits)
self.assertAllClose(loss, np.zeros(3))
self.assertAllClose(loss, np.zeros((5, 3)))
def test_box_consistency_dice_loss_full_res(self):
model = build_meta_arch(use_dice_loss=True,
predict_full_resolution_masks=True)
boxes_gt = tf.constant([[0., 0., 1.0, 1.0]])
boxes_gt = tf.constant([[[0., 0., 1.0, 1.0]]])
boxes_jittered = None
size = 32
almost_inf = 1e10
mask_prediction = np.full((1, 32, 32), -almost_inf, dtype=np.float32)
mask_prediction[0, :16, :32] = almost_inf
mask_prediction = np.full((1, 1, size, size), -almost_inf, dtype=np.float32)
mask_prediction[0, 0, :(size // 2), :] = almost_inf
loss = model._compute_per_instance_box_consistency_loss(
loss = model._compute_box_consistency_loss(
boxes_gt, boxes_jittered, tf.constant(mask_prediction))
self.assertAlmostEqual(loss[0].numpy(), 1 / 3)
self.assertAlmostEqual(loss[0, 0].numpy(), 1 / 3)
def test_get_lab_image_shape(self):
......@@ -975,18 +1039,18 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
def test_loss_keys(self):
model = build_meta_arch(use_dice_loss=True)
prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 17))] * 2,
'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 19))] * 2,
'object_center': [tf.random.normal((1, 8, 8, 6))] * 2,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * 2,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * 2
'preprocessed_inputs': tf.random.normal((3, 32, 32, 3)),
'MASK_LOGITS_GT_BOXES': [tf.random.normal((3, 5, 8, 8))] * 2,
'object_center': [tf.random.normal((3, 8, 8, 6))] * 2,
'box/offset': [tf.random.normal((3, 8, 8, 2))] * 2,
'box/scale': [tf.random.normal((3, 8, 8, 2))] * 2
}
model.provide_groundtruth(
groundtruth_boxes_list=[tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)],
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)],
groundtruth_weights_list=[tf.ones(5)],
groundtruth_masks_list=[tf.ones((5, 32, 32))])
groundtruth_boxes_list=[
tf.convert_to_tensor([[0., 0., 1., 1.]] * 5)] * 3,
groundtruth_classes_list=[tf.one_hot([1, 0, 1, 1, 1], depth=6)] * 3,
groundtruth_weights_list=[tf.ones(5)] * 3,
groundtruth_masks_list=[tf.ones((5, 32, 32))] * 3)
loss = model.loss(prediction, tf.constant([[32, 32, 3.0]]))
self.assertGreater(loss['Loss/deep_mask_estimation'], 0.0)
......@@ -1008,8 +1072,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
num_stages = 1
prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages,
'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
......@@ -1066,6 +1129,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
f'{mask_loss} did not respond to change in weight.')
def test_color_consistency_warmup(self):
tf.keras.backend.set_learning_phase(True)
model = build_meta_arch(
use_dice_loss=True,
predict_full_resolution_masks=True,
......@@ -1079,8 +1143,7 @@ class DeepMACMetaArchTest(tf.test.TestCase, parameterized.TestCase):
num_stages = 1
prediction = {
'preprocessed_inputs': tf.random.normal((1, 32, 32, 3)),
'INSTANCE_EMBEDDING': [tf.random.normal((1, 8, 8, 9))] * num_stages,
'PIXEL_EMBEDDING': [tf.random.normal((1, 8, 8, 8))] * num_stages,
'MASK_LOGITS_GT_BOXES': [tf.random.normal((1, 5, 8, 8))] * num_stages,
'object_center': [tf.random.normal((1, 8, 8, 6))] * num_stages,
'box/offset': [tf.random.normal((1, 8, 8, 2))] * num_stages,
'box/scale': [tf.random.normal((1, 8, 8, 2))] * num_stages
......
......@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 24
// Next ID 25
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
......@@ -485,6 +485,14 @@ message CenterNet {
optional int32 color_consistency_warmup_start = 23 [default=0];
// DeepMAC has been refactored to process the entire batch at once,
// instead of the previous (simple) approach of processing one sample at
// a time. Because of this, the memory consumption has increased and
// it's crucial to only feed the mask head the last stage outputs
// from the hourglass. Doing so halves the memory requirement of the
// mask head and does not cause a drop in evaluation metrics.
optional bool use_only_last_stage = 24 [default=false];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment