Commit 3ce2f61b authored by Kaushik Shivakumar

Merge branch 'master' of https://github.com/tensorflow/models into context_tf2

parents bb16d5ca 8e9296ff
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for colab tutorials located in object_detection/colab_tutorials/..."""
import base64
import io
import json
from typing import Dict
from typing import List
from typing import Union
import uuid
from IPython.display import display
from IPython.display import Javascript
import numpy as np
from PIL import Image
from google.colab import output
from google.colab.output import eval_js


def image_from_numpy(image):
  """Encode an image, given as a numpy array, in Base64.

  Args:
    image: np.ndarray
      Image represented as a numpy array.

  Returns:
    An encoded Base64 representation of the image.
  """
  with io.BytesIO() as img_output:
    Image.fromarray(image).save(img_output, format='JPEG')
    data = img_output.getvalue()
  # Strip the leading "b'" and trailing "'" from the repr of the bytes object.
  data = str(base64.b64encode(data))[2:-1]
  return data
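

# A hedged sanity check (illustrative only, not part of the public API): the
# Base64 string produced above can be embedded directly in a data URL, which
# is how the frontend JavaScript below consumes the images.
def _demo_image_from_numpy():  # hypothetical helper for illustration
  demo = np.zeros((8, 8, 3), dtype=np.uint8)  # arbitrary image content
  return 'data:image/jpeg;base64,' + image_from_numpy(demo)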


def draw_bbox(image_urls, callbackId):  # pylint: disable=invalid-name
  """Open the bounding box UI and send the results to a callback function.

  Args:
    image_urls: list[str | np.ndarray]
      List of images to annotate. If an np.ndarray is given, the array is
      interpreted as an image and sent to the frontend. The declared
      interface also allows a str path to be read into an np.ndarray, but
      the current implementation only handles np.ndarray inputs and raises
      TypeError for anything else.
    callbackId: str
      The ID for the callback function to send the bounding box results to
      when the user hits submit.
  """
  js = Javascript('''
    async function load_image(imgs, callbackId) {
      // init organizational elements
      const div = document.createElement('div');
      var image_cont = document.createElement('div');
      var errorlog = document.createElement('div');
      var crosshair_h = document.createElement('div');
      crosshair_h.style.position = "absolute";
      crosshair_h.style.backgroundColor = "transparent";
      crosshair_h.style.width = "100%";
      crosshair_h.style.height = "0px";
      crosshair_h.style.zIndex = 9998;
      crosshair_h.style.borderStyle = "dotted";
      crosshair_h.style.borderWidth = "2px";
      crosshair_h.style.borderColor = "rgba(255, 0, 0, 0.75)";
      crosshair_h.style.cursor = "crosshair";
      var crosshair_v = document.createElement('div');
      crosshair_v.style.position = "absolute";
      crosshair_v.style.backgroundColor = "transparent";
      crosshair_v.style.width = "0px";
      crosshair_v.style.height = "100%";
      crosshair_v.style.zIndex = 9999;
      crosshair_v.style.top = "0px";
      crosshair_v.style.borderStyle = "dotted";
      crosshair_v.style.borderWidth = "2px";
      crosshair_v.style.borderColor = "rgba(255, 0, 0, 0.75)";
      crosshair_v.style.cursor = "crosshair";
      crosshair_v.style.marginTop = "23px";
      var brdiv = document.createElement('br');
      // init control elements
      var next = document.createElement('button');
      var prev = document.createElement('button');
      var submit = document.createElement('button');
      var deleteButton = document.createElement('button');
      var deleteAllbutton = document.createElement('button');
      // init image containers
      var image = new Image();
      var canvas_img = document.createElement('canvas');
      var ctx = canvas_img.getContext("2d");
      canvas_img.style.cursor = "crosshair";
      canvas_img.setAttribute('draggable', false);
      crosshair_v.setAttribute('draggable', false);
      crosshair_h.setAttribute('draggable', false);
      // bounding box containers
      const height = 600;
      var allBoundingBoxes = [];
      var curr_image = 0;
      var im_height = 0;
      var im_width = 0;
      // initialize one (empty) bounding box list per image
      for (var i = 0; i < imgs.length; i++) {
        allBoundingBoxes[i] = [];
      }
      // initialize image view
      errorlog.id = 'errorlog';
      image.style.display = 'block';
      image.setAttribute('draggable', false);
      // load the first image
      var img = imgs[curr_image];
      image.src = "data:image/png;base64," + img;
      image.onload = function() {
        // normalize display height and canvas
        image.height = height;
        image_cont.height = canvas_img.height = image.height;
        image_cont.width = canvas_img.width = image.naturalWidth;
        crosshair_v.style.height = image_cont.height + "px";
        crosshair_h.style.width = image_cont.width + "px";
        // draw the new image
        ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight,
                      0, 0, canvas_img.width, canvas_img.height);
      };
      // move to the next image in the array
      next.textContent = "next image";
      next.onclick = function() {
        if (curr_image < imgs.length - 1) {
          // clear the canvas and load the new image
          curr_image += 1;
          errorlog.innerHTML = "";
        } else {
          errorlog.innerHTML = "All images completed!";
        }
        resetcanvas();
      }
      // move back to the previous image in the array
      prev.textContent = "prev image";
      prev.onclick = function() {
        if (curr_image > 0) {
          // clear the canvas and load the new image
          curr_image -= 1;
          errorlog.innerHTML = "";
        } else {
          errorlog.innerHTML = "at the beginning";
        }
        resetcanvas();
      }
      // on delete, remove the last bounding box drawn on the current image
      deleteButton.textContent = "undo bbox";
      deleteButton.onclick = function() {
        boundingBoxes.pop();
        ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
        image.src = "data:image/png;base64," + img;
        image.onload = function() {
          ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight,
                        0, 0, canvas_img.width, canvas_img.height);
          boundingBoxes.map(r => {drawRect(r)});
        };
      }
      // on delete all, remove every bounding box on the current image
      deleteAllbutton.textContent = "delete all";
      deleteAllbutton.onclick = function() {
        // empty the shared array in place so that allBoundingBoxes[curr_image]
        // (the array sent on submit) is cleared as well
        boundingBoxes.length = 0;
        ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
        image.src = "data:image/png;base64," + img;
        image.onload = function() {
          ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight,
                        0, 0, canvas_img.width, canvas_img.height);
        };
      }
      // on submit, send the boxes to the backend callback
      submit.textContent = "submit";
      submit.onclick = function() {
        errorlog.innerHTML = "";
        // send box data to the callback function
        google.colab.kernel.invokeFunction(callbackId, [allBoundingBoxes], {});
      }
      // init template for annotations
      const annotation = {
        x: 0,
        y: 0,
        w: 0,
        h: 0,
      };
      // the array of rectangles for the current image
      let boundingBoxes = allBoundingBoxes[curr_image];
      // the rectangle currently being drawn
      let o = {};
      // a variable to store the mouse position
      let m = {},
      // a variable to store the point where drawing began
      start = {};
      // a boolean variable to store the drawing state
      let isDrawing = false;
      var elem = null;
      function handleMouseDown(e) {
        // on mouse down, start tracking the mouse position
        start = oMousePos(canvas_img, e);
        // set the drawing state to true
        isDrawing = true;
      }
      function handleMouseMove(e) {
        // move the crosshairs, but only within the bounds of the canvas
        if (document.elementsFromPoint(e.pageX, e.pageY).includes(canvas_img)) {
          crosshair_h.style.top = e.pageY + "px";
          crosshair_v.style.left = e.pageX + "px";
        }
        // update the bounding box being drawn
        if (isDrawing) {
          m = oMousePos(canvas_img, e);
          draw();
        }
      }
      function handleMouseUp(e) {
        if (isDrawing) {
          // on mouse release, push a bounding box to the array and draw all boxes
          isDrawing = false;
          const box = Object.create(annotation);
          // normalize the rectangle so that (x, y) is its top-left corner
          if (o.w > 0) {
            box.x = o.x;
          } else {
            box.x = o.x + o.w;
          }
          if (o.h > 0) {
            box.y = o.y;
          } else {
            box.y = o.y + o.h;
          }
          box.w = Math.abs(o.w);
          box.h = Math.abs(o.h);
          // add the bounding box to the image
          boundingBoxes.push(box);
          draw();
        }
      }
      function draw() {
        o.x = (start.x) / image.width;        // start position of x (normalized)
        o.y = (start.y) / image.height;       // start position of y (normalized)
        o.w = (m.x - start.x) / image.width;  // width (normalized)
        o.h = (m.y - start.y) / image.height; // height (normalized)
        ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
        ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight,
                      0, 0, canvas_img.width, canvas_img.height);
        // draw all the rectangles saved for this image
        boundingBoxes.map(r => {drawRect(r)});
        // draw the rectangle currently being dragged
        drawRect(o);
      }
      // add the handlers needed for dragging
      crosshair_h.addEventListener("mousedown", handleMouseDown);
      crosshair_v.addEventListener("mousedown", handleMouseDown);
      document.addEventListener("mousemove", handleMouseMove);
      document.addEventListener("mouseup", handleMouseUp);
      function resetcanvas() {
        // clear the canvas
        ctx.clearRect(0, 0, canvas_img.width, canvas_img.height);
        img = imgs[curr_image];
        image.src = "data:image/png;base64," + img;
        // on load, init a new canvas and display the image
        image.onload = function() {
          // normalize display height and canvas
          image.height = height;
          image_cont.height = canvas_img.height = image.height;
          image_cont.width = canvas_img.width = image.naturalWidth;
          crosshair_v.style.height = image_cont.height + "px";
          crosshair_h.style.width = image_cont.width + "px";
          // draw the new image
          ctx.drawImage(image, 0, 0, image.naturalWidth, image.naturalHeight,
                        0, 0, canvas_img.width, canvas_img.height);
          // draw the bounding boxes saved for this image
          boundingBoxes = allBoundingBoxes[curr_image];
          boundingBoxes.map(r => {drawRect(r)});
        };
      }
      function drawRect(o) {
        // draw a rectangle from normalized coordinates
        ctx.strokeStyle = "red";
        ctx.lineWidth = 2;
        ctx.beginPath();
        ctx.rect(o.x * image.width, o.y * image.height,
                 o.w * image.width, o.h * image.height);
        ctx.stroke();
      }
      // function to detect the mouse position relative to the canvas
      function oMousePos(canvas_img, evt) {
        let ClientRect = canvas_img.getBoundingClientRect();
        return {
          x: evt.clientX - ClientRect.left,
          y: evt.clientY - ClientRect.top
        };
      }
      // configure the colab output display
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);
      // build the html document that will be seen in the output
      div.appendChild(document.createElement('br'));
      div.appendChild(image_cont);
      image_cont.appendChild(canvas_img);
      image_cont.appendChild(crosshair_h);
      image_cont.appendChild(crosshair_v);
      div.appendChild(document.createElement('br'));
      div.appendChild(errorlog);
      div.appendChild(prev);
      div.appendChild(next);
      div.appendChild(deleteButton);
      div.appendChild(deleteAllbutton);
      div.appendChild(document.createElement('br'));
      div.appendChild(brdiv);
      div.appendChild(submit);
      document.querySelector("#output-area").appendChild(div);
      return;
    }''')
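  # Note: the JS above reports results through
  # google.colab.kernel.invokeFunction(callbackId, ...); the Python side must
  # pair this with output.register_callback(callbackId, fn), as annotate()
  # does below.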
  # load the images as byte arrays
  bytearrays = []
  for image in image_urls:
    if isinstance(image, np.ndarray):
      bytearrays.append(image_from_numpy(image))
    else:
      raise TypeError('Image has unsupported type {}.'.format(type(image)))
  # format the arrays for input
  image_data = json.dumps(bytearrays)
  del bytearrays
  # call the JavaScript function, passing the string-encoded byte arrays
  # (image_data) as input
  display(js)
  eval_js('load_image({}, \'{}\')'.format(image_data, callbackId))
  return


def annotate(imgs: List[Union[str, np.ndarray]],  # pylint: disable=invalid-name
             box_storage_pointer: List[np.ndarray],
             callbackId: str = None):
  """Open the bounding box UI and prompt the user for input.

  Args:
    imgs: list[str | np.ndarray]
      List of images to annotate. If an np.ndarray is given, the array is
      interpreted as an image and sent to the frontend. A str is interpreted
      as a path to be read into an np.ndarray before being sent to the
      frontend (not yet handled by draw_bbox, which raises TypeError).
    box_storage_pointer: list[np.ndarray]
      Destination list for the bounding box arrays. Each array in this list
      corresponds to one of the images given in imgs. The array is an
      N x 4 array, where N is the number of bounding boxes given by the user
      for that particular image. If there are no bounding boxes for an image,
      None is used instead of an empty array.
    callbackId: str, optional
      The ID for the callback function that communicates between the frontend
      and the backend. If no ID is given, a random UUID string is used instead.
  """
  # Set a random ID for the callback function
  if callbackId is None:
    callbackId = str(uuid.uuid1()).replace('-', '')

  def dictToList(input_bbox):  # pylint: disable=invalid-name
    """Convert a bbox dictionary into a coordinate tuple.

    This function converts the dictionary from the frontend (in the format
    {x, y, w, h}, as shown in callbackFunction) into a tuple
    (y_min, x_min, y_max, x_max).

    Args:
      input_bbox: dict[str, float]
        Bounding box in the format {x, y, w, h}, where (x, y) is the top-left
        corner and w and h are the width and height, all normalized to [0, 1].

    Returns:
      A tuple with bbox coordinates in the form (ymin, xmin, ymax, xmax).
    """
    return (input_bbox['y'], input_bbox['x'], input_bbox['y'] + input_bbox['h'],
            input_bbox['x'] + input_bbox['w'])
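  # Worked example (hypothetical values):
  #   dictToList({'x': 0.1, 'y': 0.2, 'w': 0.3, 'h': 0.4})
  #   returns (0.2, 0.1, 0.6, 0.4), i.e. (ymin, xmin, ymax, xmax).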

  def callbackFunction(annotations: List[List[Dict[str, float]]]):  # pylint: disable=invalid-name
    """Callback function.

    This is the callback function that captures the data from the frontend and
    converts it into a list of numpy arrays.

    Args:
      annotations: list[list[dict[str, float]]]
        The input is a list (one entry per image) of lists of objects
        corresponding to the annotations. The format of annotations is shown
        below:
        [
          // boxes for image 1
          [
            // box 1
            {x, y, w, h},
            // box 2
            {x, y, w, h},
            ...
          ],
          // boxes for image 2
          [
            // box 1
            {x, y, w, h},
            // box 2
            {x, y, w, h},
            ...
          ],
          ...
        ]
    """
    # reset the boxes list
    nonlocal box_storage_pointer
    boxes: List[np.ndarray] = box_storage_pointer
    boxes.clear()
    # load the new annotations into the boxes list
    for annotations_per_img in annotations:
      rectangles_as_arrays = [np.clip(dictToList(annotation), 0, 1)
                              for annotation in annotations_per_img]
      if rectangles_as_arrays:
        boxes.append(np.stack(rectangles_as_arrays))
      else:
        boxes.append(None)
    # output a confirmation to the errorlog element in the frontend
    with output.redirect_to_element('#errorlog'):
      display('--boxes array populated--')

  output.register_callback(callbackId, callbackFunction)
  draw_bbox(imgs, callbackId)
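
# A minimal usage sketch (assumptions: a Colab notebook frontend, and that
# this module is importable, e.g. as `colab_utils`; the image contents are
# arbitrary):
#
#   import numpy as np
#   import colab_utils
#
#   images = [np.zeros((200, 300, 3), dtype=np.uint8)]
#   boxes = []  # populated in place once the user hits "submit"
#   colab_utils.annotate(images, box_storage_pointer=boxes)
#   # After submission, boxes[0] is an N x 4 array of normalized
#   # [ymin, xmin, ymax, xmax] rows, or None if no boxes were drawn.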
@@ -799,14 +799,14 @@ def position_sensitive_crop_regions(image,

 def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
-                                     image_width):
+                                     image_width, resize_method='bilinear'):
   """Transforms the box masks back to full image masks.

   Embeds masks in bounding boxes of larger masks whose shapes correspond to
   image shape.

   Args:
-    box_masks: A tf.float32 tensor of size [num_masks, mask_height, mask_width].
+    box_masks: A tensor of size [num_masks, mask_height, mask_width].
     boxes: A tf.float32 tensor of size [num_masks, 4] containing the box
       corners. Row i contains [ymin, xmin, ymax, xmax] of the box
       corresponding to mask i. Note that the box corners are in
@@ -815,10 +815,14 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
       the image height.
     image_width: Image width. The output mask will have the same width as the
       image width.
+    resize_method: The resize method, either 'bilinear' or 'nearest'. Note that
+      'bilinear' is only respected if box_masks is a float.

   Returns:
-    A tf.float32 tensor of size [num_masks, image_height, image_width].
+    A tensor of size [num_masks, image_height, image_width] with the same dtype
+      as `box_masks`.
   """
+  resize_method = 'nearest' if box_masks.dtype == tf.uint8 else resize_method
   # TODO(rathodv): Make this a public function.
   def reframe_box_masks_to_image_masks_default():
     """The default function when there are more than 0 box masks."""
@@ -840,16 +844,19 @@ def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
     # TODO(vighneshb) Use matmul_crop_and_resize so that the output shape
     # is static. This will help us run and test on TPUs.
-    return tf.image.crop_and_resize(
+    resized_crops = tf.image.crop_and_resize(
         image=box_masks_expanded,
         boxes=reverse_boxes,
         box_ind=tf.range(num_boxes),
         crop_size=[image_height, image_width],
-        extrapolation_value=0.0)
+        method=resize_method,
+        extrapolation_value=0)
+    return tf.cast(resized_crops, box_masks.dtype)

   image_masks = tf.cond(
       tf.shape(box_masks)[0] > 0,
       reframe_box_masks_to_image_masks_default,
-      lambda: tf.zeros([0, image_height, image_width, 1], dtype=tf.float32))
+      lambda: tf.zeros([0, image_height, image_width, 1], box_masks.dtype))
   return tf.squeeze(image_masks, axis=3)
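For reference, a hedged sketch of the new behavior (function and argument
names as in the hunk above; the import path and tensor values are
illustrative):

  import tensorflow as tf
  from object_detection.utils import ops  # assumed import path

  box_masks = tf.constant([[[1, 0], [0, 1]]], dtype=tf.uint8)
  boxes = tf.constant([[0.25, 0.5, 0.75, 1.0]], dtype=tf.float32)
  image_masks = ops.reframe_box_masks_to_image_masks(
      box_masks, boxes, image_height=4, image_width=4,
      resize_method='bilinear')  # silently overridden to 'nearest' for uint8
  # image_masks: shape [1, 4, 4], same dtype as box_masks (tf.uint8)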
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np
 import six
 from six.moves import range
@@ -1190,36 +1192,59 @@ class OpsTestBatchPositionSensitiveCropRegions(test_case.TestCase):

 # The following tests are only executed on CPU because the output
 # shape is not constant.
-class ReframeBoxMasksToImageMasksTest(test_case.TestCase):
+class ReframeBoxMasksToImageMasksTest(test_case.TestCase,
+                                      parameterized.TestCase):

-  def testZeroImageOnEmptyMask(self):
+  @parameterized.parameters(
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'nearest'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'nearest'},
+  )
+  def testZeroImageOnEmptyMask(self, mask_dtype, mask_dtype_np, resize_method):
     np_expected_image_masks = np.array([[[0, 0, 0, 0],
                                          [0, 0, 0, 0],
                                          [0, 0, 0, 0],
-                                         [0, 0, 0, 0]]], dtype=np.float32)
+                                         [0, 0, 0, 0]]])
     def graph_fn():
       box_masks = tf.constant([[[0, 0],
-                                [0, 0]]], dtype=tf.float32)
+                                [0, 0]]], dtype=mask_dtype)
       boxes = tf.constant([[0.0, 0.0, 1.0, 1.0]], dtype=tf.float32)
-      image_masks = ops.reframe_box_masks_to_image_masks(box_masks, boxes,
-                                                         image_height=4,
-                                                         image_width=4)
+      image_masks = ops.reframe_box_masks_to_image_masks(
+          box_masks, boxes, image_height=4, image_width=4,
+          resize_method=resize_method)
       return image_masks
     np_image_masks = self.execute_cpu(graph_fn, [])
+    self.assertEqual(np_image_masks.dtype, mask_dtype_np)
     self.assertAllClose(np_image_masks, np_expected_image_masks)

-  def testZeroBoxMasks(self):
+  @parameterized.parameters(
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'nearest'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'nearest'},
+  )
+  def testZeroBoxMasks(self, mask_dtype, mask_dtype_np, resize_method):
     def graph_fn():
-      box_masks = tf.zeros([0, 3, 3], dtype=tf.float32)
+      box_masks = tf.zeros([0, 3, 3], dtype=mask_dtype)
       boxes = tf.zeros([0, 4], dtype=tf.float32)
-      image_masks = ops.reframe_box_masks_to_image_masks(box_masks, boxes,
-                                                         image_height=4,
-                                                         image_width=4)
+      image_masks = ops.reframe_box_masks_to_image_masks(
+          box_masks, boxes, image_height=4, image_width=4,
+          resize_method=resize_method)
       return image_masks
     np_image_masks = self.execute_cpu(graph_fn, [])
+    self.assertEqual(np_image_masks.dtype, mask_dtype_np)
     self.assertAllEqual(np_image_masks.shape, np.array([0, 4, 4]))

   def testBoxWithZeroArea(self):
@@ -1235,40 +1260,70 @@ class ReframeBoxMasksToImageMasksTest(test_case.TestCase):
     np_image_masks = self.execute_cpu(graph_fn, [])
     self.assertAllEqual(np_image_masks.shape, np.array([1, 4, 4]))

-  def testMaskIsCenteredInImageWhenBoxIsCentered(self):
+  @parameterized.parameters(
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'nearest'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'nearest'},
+  )
+  def testMaskIsCenteredInImageWhenBoxIsCentered(self, mask_dtype,
+                                                 mask_dtype_np, resize_method):
     def graph_fn():
-      box_masks = tf.constant([[[1, 1],
-                                [1, 1]]], dtype=tf.float32)
+      box_masks = tf.constant([[[4, 4],
+                                [4, 4]]], dtype=mask_dtype)
       boxes = tf.constant([[0.25, 0.25, 0.75, 0.75]], dtype=tf.float32)
-      image_masks = ops.reframe_box_masks_to_image_masks(box_masks, boxes,
-                                                         image_height=4,
-                                                         image_width=4)
+      image_masks = ops.reframe_box_masks_to_image_masks(
+          box_masks, boxes, image_height=4, image_width=4,
+          resize_method=resize_method)
       return image_masks

     np_expected_image_masks = np.array([[[0, 0, 0, 0],
-                                         [0, 1, 1, 0],
-                                         [0, 1, 1, 0],
-                                         [0, 0, 0, 0]]], dtype=np.float32)
+                                         [0, 4, 4, 0],
+                                         [0, 4, 4, 0],
+                                         [0, 0, 0, 0]]], dtype=mask_dtype_np)
     np_image_masks = self.execute_cpu(graph_fn, [])
+    self.assertEqual(np_image_masks.dtype, mask_dtype_np)
     self.assertAllClose(np_image_masks, np_expected_image_masks)

-  def testMaskOffCenterRemainsOffCenterInImage(self):
+  @parameterized.parameters(
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.float32, 'mask_dtype_np': np.float32,
+       'resize_method': 'nearest'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'bilinear'},
+      {'mask_dtype': tf.uint8, 'mask_dtype_np': np.uint8,
+       'resize_method': 'nearest'},
+  )
+  def testMaskOffCenterRemainsOffCenterInImage(self, mask_dtype,
+                                               mask_dtype_np, resize_method):
     def graph_fn():
       box_masks = tf.constant([[[1, 0],
-                                [0, 1]]], dtype=tf.float32)
+                                [0, 1]]], dtype=mask_dtype)
       boxes = tf.constant([[0.25, 0.5, 0.75, 1.0]], dtype=tf.float32)
-      image_masks = ops.reframe_box_masks_to_image_masks(box_masks, boxes,
-                                                         image_height=4,
-                                                         image_width=4)
+      image_masks = ops.reframe_box_masks_to_image_masks(
+          box_masks, boxes, image_height=4, image_width=4,
+          resize_method=resize_method)
       return image_masks

-    np_expected_image_masks = np.array([[[0, 0, 0, 0],
-                                         [0, 0, 0.6111111, 0.16666669],
-                                         [0, 0, 0.3888889, 0.83333337],
-                                         [0, 0, 0, 0]]], dtype=np.float32)
+    if mask_dtype == tf.float32 and resize_method == 'bilinear':
+      np_expected_image_masks = np.array([[[0, 0, 0, 0],
+                                           [0, 0, 0.6111111, 0.16666669],
+                                           [0, 0, 0.3888889, 0.83333337],
+                                           [0, 0, 0, 0]]], dtype=np.float32)
+    else:
+      np_expected_image_masks = np.array([[[0, 0, 0, 0],
+                                           [0, 0, 1, 0],
+                                           [0, 0, 0, 1],
+                                           [0, 0, 0, 0]]], dtype=mask_dtype_np)
     np_image_masks = self.execute_cpu(graph_fn, [])
+    self.assertEqual(np_image_masks.dtype, mask_dtype_np)
     self.assertAllClose(np_image_masks, np_expected_image_masks)
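The tests above lean on absl's parameterized test expansion; for readers
unfamiliar with it, a minimal standalone sketch of the pattern (a toy test,
unrelated to the code above):

  from absl.testing import parameterized

  class ToyTest(parameterized.TestCase):

    @parameterized.parameters(
        {'value': 1, 'expected': 2},
        {'value': 3, 'expected': 6},
    )
    def test_double(self, value, expected):
      self.assertEqual(value * 2, expected)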
@@ -790,6 +790,81 @@ def draw_side_by_side_evaluation_image(eval_dict,
   return images_with_detections_list


+def draw_densepose_visualizations(eval_dict,
+                                  max_boxes_to_draw=20,
+                                  min_score_thresh=0.2,
+                                  num_parts=24,
+                                  dp_coord_to_visualize=0):
+  """Draws DensePose visualizations.
+
+  Args:
+    eval_dict: The evaluation dictionary returned by
+      eval_util.result_dict_for_batched_example().
+    max_boxes_to_draw: The maximum number of boxes to draw for detections.
+    min_score_thresh: The minimum score threshold for showing detections.
+    num_parts: The number of different densepose parts.
+    dp_coord_to_visualize: Whether to visualize v-coordinates (0) or
+      u-coordinates (1) overlaid on the person masks.
+
+  Returns:
+    A list of [1, H, W, C] uint8 tensors, each element corresponding to an
+    image in the batch.
+
+  Raises:
+    ValueError: If `dp_coord_to_visualize` is not 0 or 1.
+  """
+  if dp_coord_to_visualize not in (0, 1):
+    raise ValueError('`dp_coord_to_visualize` must be either 0 (for v '
+                     'coordinates) or 1 (for u coordinates), but instead got '
+                     '{}'.format(dp_coord_to_visualize))
+  detection_fields = fields.DetectionResultFields()
+  input_data_fields = fields.InputDataFields()
+  if detection_fields.detection_masks not in eval_dict:
+    raise ValueError('Expected `detection_masks` in `eval_dict`.')
+  if detection_fields.detection_surface_coords not in eval_dict:
+    raise ValueError('Expected `detection_surface_coords` in `eval_dict`.')
+  images_with_detections_list = []
+
+  for indx in range(eval_dict[input_data_fields.original_image].shape[0]):
+    # Note that detection masks have already been resized to the original
+    # image shapes, but `original_image` has not.
+    # TODO(ronnyvotel): Consider resizing `original_image` in
+    # eval_util.result_dict_for_batched_example().
+    true_shape = eval_dict[input_data_fields.true_image_shape][indx]
+    original_shape = eval_dict[
+        input_data_fields.original_image_spatial_shape][indx]
+    image = eval_dict[input_data_fields.original_image][indx]
+    image = shape_utils.pad_or_clip_nd(image, [true_shape[0], true_shape[1], 3])
+    image = _resize_original_image(image, original_shape)
+
+    scores = eval_dict[detection_fields.detection_scores][indx]
+    detection_masks = eval_dict[detection_fields.detection_masks][indx]
+    surface_coords = eval_dict[detection_fields.detection_surface_coords][indx]
+
+    def draw_densepose_py_func(image, detection_masks, surface_coords, scores):
+      """Overlays part masks and surface coords on the original images."""
+      surface_coord_image = np.copy(image)
+      for i, (score, surface_coord, mask) in enumerate(
+          zip(scores, surface_coords, detection_masks)):
+        if i == max_boxes_to_draw:
+          break
+        if score > min_score_thresh:
+          draw_part_mask_on_image_array(image, mask, num_parts=num_parts)
+          draw_float_channel_on_image_array(
+              surface_coord_image, surface_coord[:, :, dp_coord_to_visualize],
+              mask)
+      return np.concatenate([image, surface_coord_image], axis=1)
+
+    image_with_densepose = tf.py_func(
+        draw_densepose_py_func,
+        [image, detection_masks, surface_coords, scores],
+        tf.uint8)
+    images_with_detections_list.append(
+        image_with_densepose[tf.newaxis, :, :, :])
+  return images_with_detections_list
+
+
 def draw_keypoints_on_image_array(image,
                                   keypoints,
                                   keypoint_scores=None,
@@ -918,8 +993,6 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
     raise ValueError('`image` not of type np.uint8')
   if mask.dtype != np.uint8:
     raise ValueError('`mask` not of type np.uint8')
-  if np.any(np.logical_and(mask != 1, mask != 0)):
-    raise ValueError('`mask` elements should be in [0, 1]')
   if image.shape[:2] != mask.shape:
     raise ValueError('The image has spatial dimensions %s but the mask has '
                      'dimensions %s' % (image.shape[:2], mask.shape))
@@ -929,11 +1002,85 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
   solid_color = np.expand_dims(
       np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
   pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
-  pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
+  pil_mask = Image.fromarray(np.uint8(255.0*alpha*(mask > 0))).convert('L')
   pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
   np.copyto(image, np.array(pil_image.convert('RGB')))
+
+
+def draw_part_mask_on_image_array(image, mask, alpha=0.4, num_parts=24):
+  """Draws a part mask on an image.
+
+  Args:
+    image: uint8 numpy array with shape (img_height, img_width, 3)
+    mask: a uint8 numpy array of shape (img_height, img_width) with
+      1-indexed parts (0 for background).
+    alpha: transparency value between 0 and 1 (default: 0.4)
+    num_parts: the maximum number of parts that may exist in the image (default
+      24 for DensePose).
+
+  Raises:
+    ValueError: On incorrect data type for image or masks.
+  """
+  if image.dtype != np.uint8:
+    raise ValueError('`image` not of type np.uint8')
+  if mask.dtype != np.uint8:
+    raise ValueError('`mask` not of type np.uint8')
+  if image.shape[:2] != mask.shape:
+    raise ValueError('The image has spatial dimensions %s but the mask has '
+                     'dimensions %s' % (image.shape[:2], mask.shape))
+  pil_image = Image.fromarray(image)
+  part_colors = np.zeros_like(image)
+  mask_1_channel = mask[:, :, np.newaxis]
+  for i, color in enumerate(STANDARD_COLORS[:num_parts]):
+    rgb = np.array(ImageColor.getrgb(color), dtype=np.uint8)
+    part_colors += (mask_1_channel == i + 1) * rgb[np.newaxis, np.newaxis, :]
+  pil_part_colors = Image.fromarray(np.uint8(part_colors)).convert('RGBA')
+  pil_mask = Image.fromarray(np.uint8(255.0 * alpha * (mask > 0))).convert('L')
+  pil_image = Image.composite(pil_part_colors, pil_image, pil_mask)
+  np.copyto(image, np.array(pil_image.convert('RGB')))
+
+
+def draw_float_channel_on_image_array(image, channel, mask, alpha=0.9,
+                                      cmap='YlGn'):
+  """Draws a floating point channel on an image array.
+
+  Args:
+    image: uint8 numpy array with shape (img_height, img_width, 3)
+    channel: float32 numpy array with shape (img_height, img_width). The
+      values should be in the range [0, 1], and will be mapped to colors using
+      the provided colormap `cmap` argument.
+    mask: a uint8 numpy array of shape (img_height, img_width) with
+      1-indexed parts (0 for background).
+    alpha: transparency value between 0 and 1 (default: 0.9)
+    cmap: string with the colormap to use.
+
+  Raises:
+    ValueError: On incorrect data type for image or masks.
+  """
+  if image.dtype != np.uint8:
+    raise ValueError('`image` not of type np.uint8')
+  if channel.dtype != np.float32:
+    raise ValueError('`channel` not of type np.float32')
+  if mask.dtype != np.uint8:
+    raise ValueError('`mask` not of type np.uint8')
+  if image.shape[:2] != channel.shape:
+    raise ValueError('The image has spatial dimensions %s but the channel has '
+                     'dimensions %s' % (image.shape[:2], channel.shape))
+  if image.shape[:2] != mask.shape:
+    raise ValueError('The image has spatial dimensions %s but the mask has '
+                     'dimensions %s' % (image.shape[:2], mask.shape))
+  cm = plt.get_cmap(cmap)
+  pil_image = Image.fromarray(image)
+  colored_channel = cm(channel)[:, :, :3]
+  pil_colored_channel = Image.fromarray(
+      np.uint8(colored_channel * 255)).convert('RGBA')
+  pil_mask = Image.fromarray(np.uint8(255.0 * alpha * (mask > 0))).convert('L')
+  pil_image = Image.composite(pil_colored_channel, pil_image, pil_mask)
+  np.copyto(image, np.array(pil_image.convert('RGB')))
+
+
 def visualize_boxes_and_labels_on_image_array(
     image,
     boxes,
@@ -973,8 +1120,8 @@ def visualize_boxes_and_labels_on_image_array(
       boxes and plot all boxes as black with no classes or scores.
     category_index: a dict containing category dictionaries (each holding
       category index `id` and category name `name`) keyed by category indices.
-    instance_masks: a numpy array of shape [N, image_height, image_width] with
-      values ranging between 0 and 1, can be None.
+    instance_masks: a uint8 numpy array of shape [N, image_height, image_width],
+      can be None.
     instance_boundaries: a numpy array of shape [N, image_height, image_width]
       with values ranging between 0 and 1, can be None.
     keypoints: a numpy array of shape [N, num_keypoints, 2], can
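A hedged sketch exercising the two new drawing helpers from the hunks above
(toy arrays matching the docstring dtypes; the import alias is an assumption):

  import numpy as np
  from object_detection.utils import visualization_utils as viz  # assumed path

  image = np.zeros((2, 2, 3), dtype=np.uint8)
  part_mask = np.array([[0, 1], [1, 6]], dtype=np.uint8)  # 0 = background
  viz.draw_part_mask_on_image_array(image, part_mask, alpha=0.5)

  channel = np.array([[0.0, 0.5], [0.0, 1.0]], dtype=np.float32)
  viz.draw_float_channel_on_image_array(image, channel, part_mask, alpha=0.9)
  # Both helpers modify `image` in place; pixels where part_mask == 0 are
  # left untouched.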
@@ -373,6 +373,38 @@ class VisualizationUtilsTest(test_case.TestCase):
                                                  color='Blue', alpha=.5)
     self.assertAllEqual(test_image, expected_result)

+  def test_draw_part_mask_on_image_array(self):
+    test_image = np.asarray([[[0, 0, 0], [0, 0, 0]],
+                             [[0, 0, 0], [0, 0, 0]]], dtype=np.uint8)
+    mask = np.asarray([[0, 1],
+                       [1, 6]], dtype=np.uint8)
+    visualization_utils.draw_part_mask_on_image_array(test_image, mask,
+                                                      alpha=.5)
+
+    self.assertAllEqual([0, 0, 0], test_image[0, 0])
+    self.assertAllGreater(test_image[0, 1], 0)
+    self.assertAllGreater(test_image[1, 0], 0)
+    self.assertAllGreater(test_image[1, 1], 0)
+    self.assertAllEqual(test_image[0, 1], test_image[1, 0])
+
+  def test_draw_float_channel_on_image_array(self):
+    test_image = np.asarray([[[0, 0, 0], [0, 0, 0]],
+                             [[0, 0, 0], [0, 0, 0]]], dtype=np.uint8)
+    channel = np.asarray([[0., 0.5],
+                          [0., 1.]], dtype=np.float32)
+    mask = np.asarray([[0, 1],
+                       [1, 1]], dtype=np.uint8)
+
+    # The colormap ('bwr') maps the values as follows:
+    #   0.0 -> Blue
+    #   0.5 -> White
+    #   1.0 -> Red
+    visualization_utils.draw_float_channel_on_image_array(
+        test_image, channel, mask, alpha=1.0, cmap='bwr')
+    expected_result = np.asarray([[[0, 0, 0], [255, 254, 254]],
+                                  [[0, 0, 255], [255, 0, 0]]], dtype=np.uint8)
+    self.assertAllEqual(test_image, expected_result)
+
   def test_draw_heatmaps_on_image(self):
     test_image = self.create_colorful_test_image()
     test_image = Image.fromarray(test_image)