Unverified Commit 09d9656f authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents ac671306 49a5706c
......@@ -31,7 +31,10 @@ from object_detection.utils import shape_utils
from object_detection.utils import tf_version
# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
from official.vision.image_classification.efficientnet import efficientnet_model
try:
from official.legacy.image_classification.efficientnet import efficientnet_model
except ModuleNotFoundError:
from official.vision.image_classification.efficientnet import efficientnet_model
_EFFICIENTNET_LEVEL_ENDPOINTS = {
1: 'stack_0/block_0/project_bn',
......
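A note on the EfficientNet import fallback above: newer tf-models-official releases moved the EfficientNet Keras code from official.vision.image_classification to official.legacy.image_classification, which is what the try/except accommodates. The same fallback can be written generically with importlib; the sketch below is illustrative only and not part of this commit.

```python
import importlib

# Try the newer tf-models-official layout first, then fall back to the
# older one. The dotted paths are the same modules used in the diff above.
_CANDIDATE_MODULES = (
    'official.legacy.image_classification.efficientnet.efficientnet_model',
    'official.vision.image_classification.efficientnet.efficientnet_model',
)

efficientnet_model = None
for _path in _CANDIDATE_MODULES:
  try:
    efficientnet_model = importlib.import_module(_path)
    break
  except ModuleNotFoundError:
    continue

if efficientnet_model is None:
  raise ImportError('efficientnet_model not found in either '
                    'tf-models-official layout.')
```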
......@@ -3,9 +3,6 @@ import os
from setuptools import find_packages
from setuptools import setup
# Note: adding apache-beam to required packages causes a conflict with the
# tf-models-official requirements. These packages request incompatible
# versions of the oauth2client package.
REQUIRED_PACKAGES = [
# Required for apache-beam with PY3
'avro-python3',
......@@ -23,9 +20,7 @@ REQUIRED_PACKAGES = [
'pandas',
'tf-models-official>=2.5.1',
'tensorflow_io',
# Workaround due to
# https://github.com/keras-team/keras/issues/15583
'keras==2.6.0'
'keras'
]
setup(
......
......@@ -403,6 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 24
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
......@@ -471,6 +472,19 @@ message CenterNet {
optional float color_consistency_loss_weight = 19 [default=0.0];
optional LossNormalize box_consistency_loss_normalize = 20 [
default=NORMALIZE_AUTO];
// If set, uses the bounding box tightness prior approach: the max is
// restricted to pixels inside the box along both dimensions. See details
// here:
// https://papers.nips.cc/paper/2019/hash/e6e713296627dff6475085cc6a224464-Abstract.html
optional bool box_consistency_tightness = 21 [default=false];
optional int32 color_consistency_warmup_steps = 22 [default=0];
optional int32 color_consistency_warmup_start = 23 [default=0];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......@@ -483,6 +497,12 @@ message CenterNet {
optional PostProcessing post_processing = 24;
}
enum LossNormalize {
NORMALIZE_AUTO = 0; // SUM for 2D inputs (dice loss) and MEAN for others.
NORMALIZE_GROUNDTRUTH_COUNT = 1;
NORMALIZE_BALANCED = 3;
}
message CenterNetFeatureExtractor {
optional string type = 1;
......
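As a rough illustration of the box_consistency_tightness comment above: with the tightness prior, the per-row and per-column max of the predicted mask is taken only over pixels inside the groundtruth box rather than over the whole image. The NumPy sketch below uses made-up names and is not the commit's implementation (the actual loss is implemented elsewhere in the codebase).

```python
import numpy as np

def box_consistency_maxes(mask_logits, box, tightness=False):
  """Per-row / per-column maxes of a [H, W] float mask, optionally box-restricted.

  mask_logits: [H, W] float array of predicted mask logits.
  box: (ymin, xmin, ymax, xmax) integer pixel coordinates of the groundtruth box.
  tightness: if True, mimic the tightness prior by taking the max only over
    pixels inside the box along both dimensions.
  """
  ymin, xmin, ymax, xmax = box
  if tightness:
    # Rows/columns entirely outside the box end up as -inf and would be
    # masked out of any real loss computation.
    restricted = np.full_like(mask_logits, -np.inf)
    restricted[ymin:ymax, xmin:xmax] = mask_logits[ymin:ymax, xmin:xmax]
  else:
    restricted = mask_logits
  max_per_row = restricted.max(axis=1)  # one value per y coordinate
  max_per_col = restricted.max(axis=0)  # one value per x coordinate
  return max_per_row, max_per_col
```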
......@@ -870,6 +870,9 @@ class OpenImagesChallengeEvaluator(OpenImagesDetectionEvaluator):
image_classes = groundtruth_dict[input_fields.groundtruth_image_classes]
elif input_fields.groundtruth_labeled_classes in groundtruth_dict:
image_classes = groundtruth_dict[input_fields.groundtruth_labeled_classes]
else:
logging.warning('No image classes field found for image with id %s!',
image_id)
image_classes -= self._label_id_offset
self._evaluatable_labels[image_id] = np.unique(
np.concatenate((image_classes, groundtruth_classes)))
......
......@@ -236,6 +236,77 @@ def compute_floor_offsets_with_indices(y_source,
return offsets, indices
def coordinates_to_iou(y_grid, x_grid, blist,
channels_onehot, weights=None):
"""Computes a per-pixel IoU with groundtruth boxes.
At each pixel, we return the IoU assuming that we predicted the
ideal height and width for the box at that location.
Args:
y_grid: A 2D tensor with shape [height, width] which contains the grid
y-coordinates given in the (output) image dimensions.
x_grid: A 2D tensor with shape [height, width] which contains the grid
x-coordinates given in the (output) image dimensions.
blist: A BoxList object with `num_instances` number of boxes.
channels_onehot: A 2D tensor with shape [num_instances, num_channels]
representing the one-hot encoded channel labels for each point.
weights: A 1D tensor with shape [num_instances] corresponding to the
weight of each instance.
Returns:
iou_heatmap: A [height, width, num_channels] shaped float tensor denoting
the IoU-based heatmap.
"""
image_height, image_width = tf.shape(y_grid)[0], tf.shape(y_grid)[1]
num_pixels = image_height * image_width
_, _, height, width = blist.get_center_coordinates_and_sizes()
num_boxes = tf.shape(height)[0]
per_pixel_ymin = (y_grid[tf.newaxis, :, :] -
(height[:, tf.newaxis, tf.newaxis] / 2.0))
per_pixel_xmin = (x_grid[tf.newaxis, :, :] -
(width[:, tf.newaxis, tf.newaxis] / 2.0))
per_pixel_ymax = (y_grid[tf.newaxis, :, :] +
(height[:, tf.newaxis, tf.newaxis] / 2.0))
per_pixel_xmax = (x_grid[tf.newaxis, :, :] +
(width[:, tf.newaxis, tf.newaxis] / 2.0))
# [num_boxes, height, width] -> [num_boxes * height * width]
per_pixel_ymin = tf.reshape(
per_pixel_ymin, [num_pixels * num_boxes])
per_pixel_xmin = tf.reshape(
per_pixel_xmin, [num_pixels * num_boxes])
per_pixel_ymax = tf.reshape(
per_pixel_ymax, [num_pixels * num_boxes])
per_pixel_xmax = tf.reshape(
per_pixel_xmax, [num_pixels * num_boxes])
per_pixel_blist = box_list.BoxList(
tf.stack([per_pixel_ymin, per_pixel_xmin,
per_pixel_ymax, per_pixel_xmax], axis=1))
target_boxes = tf.tile(
blist.get()[:, tf.newaxis, :], [1, num_pixels, 1])
# [num_boxes, height * width, 4] -> [num_boxes * height * width, 4]
target_boxes = tf.reshape(target_boxes,
[num_pixels * num_boxes, 4])
target_blist = box_list.BoxList(target_boxes)
ious = box_list_ops.matched_iou(target_blist, per_pixel_blist)
ious = tf.reshape(ious, [num_boxes, image_height, image_width])
per_class_iou = (
ious[:, :, :, tf.newaxis] *
channels_onehot[:, tf.newaxis, tf.newaxis, :])
if weights is not None:
per_class_iou = (
per_class_iou * weights[:, tf.newaxis, tf.newaxis, tf.newaxis])
per_class_iou = tf.maximum(per_class_iou, 0.0)
return tf.reduce_max(per_class_iou, axis=0)
def get_valid_keypoint_mask_for_class(keypoint_coordinates,
class_id,
class_onehot,
......
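To make the per-pixel IoU above concrete: at pixel (0, 0) with groundtruth box [0, 0, 32, 32], the "ideal" prediction is a 32x32 box centered at that pixel, i.e. [-16, -16, 16, 16]; it overlaps the groundtruth on a 16x16 patch, so IoU = 256 / (1024 + 1024 - 256) = 1/7, matching the expected values in the test further down. A small standalone check of this arithmetic (illustrative only, using a hypothetical helper, not the library function):

```python
def iou_at_pixel(y, x, box):
  """IoU between `box` (ymin, xmin, ymax, xmax) and a same-sized box
  centered at pixel (y, x), mirroring the per-pixel IoU described above."""
  ymin, xmin, ymax, xmax = box
  h, w = ymax - ymin, xmax - xmin
  pred = (y - h / 2.0, x - w / 2.0, y + h / 2.0, x + w / 2.0)
  inter_h = max(0.0, min(ymax, pred[2]) - max(ymin, pred[0]))
  inter_w = max(0.0, min(xmax, pred[3]) - max(xmin, pred[1]))
  inter = inter_h * inter_w
  union = 2.0 * h * w - inter  # both boxes have area h * w
  return inter / union

print(iou_at_pixel(0, 0, (0.0, 0.0, 32.0, 32.0)))  # 0.142857... == 1/7
print(iou_at_pixel(2, 2, (0.0, 0.0, 4.0, 4.0)))    # 1.0, prediction matches exactly
print(iou_at_pixel(8, 8, (0.0, 0.0, 4.0, 4.0)))    # 0.0, no overlap
```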
......@@ -18,6 +18,7 @@ from absl.testing import parameterized
import numpy as np
import tensorflow.compat.v1 as tf
from object_detection.core import box_list
from object_detection.utils import target_assigner_utils as ta_utils
from object_detection.utils import test_case
......@@ -265,6 +266,31 @@ class TargetUtilTest(parameterized.TestCase, test_case.TestCase):
np.array([[0.0, 3.0, 4.0, 0.0, 4.0]]))
self.assertAllEqual(valid, [[False, True, True, False, True]])
def test_coordinates_to_iou(self):
def graph_fn():
y, x = tf.meshgrid(tf.range(32, dtype=tf.float32),
tf.range(32, dtype=tf.float32))
blist = box_list.BoxList(
tf.constant([[0., 0., 32., 32.],
[0., 0., 16., 16.],
[0.0, 0.0, 4.0, 4.0]]))
classes = tf.constant([[0., 1., 0.],
[1., 0., 0.],
[0., 0., 1.]])
result = ta_utils.coordinates_to_iou(
y, x, blist, classes)
return result
result = self.execute(graph_fn, [])
self.assertEqual(result.shape, (32, 32, 3))
self.assertAlmostEqual(result[0, 0, 0], 1.0 / 7.0)
self.assertAlmostEqual(result[0, 0, 1], 1.0 / 7.0)
self.assertAlmostEqual(result[0, 16, 0], 1.0 / 7.0)
self.assertAlmostEqual(result[2, 2, 2], 1.0)
self.assertAlmostEqual(result[8, 8, 2], 0.0)
if __name__ == '__main__':
tf.test.main()
......@@ -1294,7 +1294,7 @@ def add_cdf_image_summary(values, name):
fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32)
/ cumulative_values.size)
fig = plt.figure(frameon=False)
ax = fig.add_subplot('111')
ax = fig.add_subplot(1, 1, 1)
ax.plot(fraction_of_examples, cumulative_values)
ax.set_ylabel('cumulative normalized values')
ax.set_xlabel('fraction of examples')
......@@ -1321,7 +1321,7 @@ def add_hist_image_summary(values, bins, name):
def hist_plot(values, bins):
"""Numpy function to plot hist."""
fig = plt.figure(frameon=False)
ax = fig.add_subplot('111')
ax = fig.add_subplot(1, 1, 1)
y, x = np.histogram(values, bins=bins)
ax.plot(x[:-1], y)
ax.set_ylabel('count')
......
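The two add_subplot changes above swap the old string form fig.add_subplot('111'), which recent Matplotlib releases reject, for the positional three-argument form. Below is a minimal standalone version of the histogram helper's plotting pattern; it is a sketch only, not the library code.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so no display is required
import matplotlib.pyplot as plt
import numpy as np

values = np.random.uniform(size=1000)
fig = plt.figure(frameon=False)
ax = fig.add_subplot(1, 1, 1)  # positional form works on both old and new Matplotlib
y, x = np.histogram(values, bins=50)
ax.plot(x[:-1], y)
ax.set_ylabel('count')
ax.set_xlabel('value')
fig.canvas.draw()
```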