"examples/pytorch/vscode:/vscode.git/clone" did not exist on "36418292114b3f62b06c8d1e27c95f3a3615b44a"
Commit ad480628 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480787331
parent bbaf7373
......@@ -575,14 +575,11 @@ def resize_and_crop_boxes(boxes,
return boxes
def resize_and_crop_masks(masks,
image_scale,
output_size,
offset):
def resize_and_crop_masks(masks, image_scale, output_size, offset):
"""Resizes boxes to output size with scale and offset.
Args:
masks: `Tensor` of shape [N, H, W, 1] representing ground truth masks.
masks: `Tensor` of shape [N, H, W, C] representing ground truth masks.
image_scale: 2D float `Tensor` representing scale factors that apply to
[height, width] of input image.
output_size: 2D `Tensor` or `int` representing [height, width] of target
......@@ -591,13 +588,17 @@ def resize_and_crop_masks(masks,
boxes.
Returns:
masks: `Tensor` of shape [N, H, W, 1] representing the scaled masks.
masks: `Tensor` of shape [N, H, W, C] representing the scaled masks.
"""
with tf.name_scope('resize_and_crop_masks'):
mask_size = tf.cast(tf.shape(masks)[1:3], tf.float32)
num_channels = tf.shape(masks)[3]
# Pad masks to avoid empty mask annotations.
masks = tf.concat(
[tf.zeros([1, mask_size[0], mask_size[1], 1]), masks], axis=0)
masks = tf.concat([
tf.zeros([1, mask_size[0], mask_size[1], num_channels],
dtype=masks.dtype), masks
],
axis=0)
scaled_size = tf.cast(image_scale * mask_size, tf.int32)
scaled_masks = tf.image.resize(
......
......@@ -241,6 +241,50 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
np.array(expected_shape[:-1]) / np.array(input_shape[:-1]))
self.assertAllEqual(image_info[3], [0, 0])
def test_resize_and_crop_masks(self):
# shape: (2, 1, 4, 3)
masks = tf.constant([[[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9, 10, 11],
]], [[
[12, 13, 14],
[15, 16, 17],
[18, 19, 20],
[21, 22, 23],
]]])
output = preprocess_ops.resize_and_crop_masks(
masks, image_scale=[2.0, 0.5], output_size=[2, 3], offset=[1, 0])
# shape: (2, 2, 3, 3)
expected_output = tf.constant([
[
[
[3, 4, 5],
[9, 10, 11],
[0, 0, 0],
],
[
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
],
],
[
[
[15, 16, 17],
[21, 22, 23],
[0, 0, 0],
],
[
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
],
],
])
self.assertAllEqual(expected_output, output)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment