mask_ops.py 1.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for masks."""

import tensorflow as tf


def resize_and_rescale_offsets(input_tensor: tf.Tensor, target_size):
  """Bilinearly resizes the offsets and rescales them to the target size.

  Offsets are expressed in pixels, so when the spatial resolution changes the
  offset values themselves must be multiplied by the same ratio; otherwise
  they would keep pointing at distances measured in the old resolution.

  Reference:
    https://github.com/google-research/deeplab2/blob/main/model/utils.py#L157

  Args:
    input_tensor: A tf.Tensor of shape [batch, height, width, 2]. Along the
      last axis, channel 0 holds the y-offsets and channel 1 the x-offsets.
      The height and width must be greater than 1, because the scale factors
      divide by (height - 1) and (width - 1).
    target_size: A list or tuple or 1D tf.Tensor that specifies the height and
      width after resizing.

  Returns:
    The input_tensor resized to shape `[batch, target_height, target_width, 2]`.
      Moreover, the offsets along the y-axis are rescaled by a factor equal to
      (target_height - 1) / (input_height - 1) and the offsets along the
      x-axis are rescaled by a factor equal to
      (target_width - 1) / (input_width - 1).
  """
  input_size_y = tf.shape(input_tensor)[1]
  input_size_x = tf.shape(input_tensor)[2]
  dtype = input_tensor.dtype

  # Ratio between the last valid pixel index of the target and of the input.
  scale_y = tf.cast(target_size[0] - 1, dtype=dtype) / tf.cast(
      input_size_y - 1, dtype=dtype)
  scale_x = tf.cast(target_size[1] - 1, dtype=dtype) / tf.cast(
      input_size_x - 1, dtype=dtype)

  # Scale each offset channel by the ratio of its own axis.
  target_y, target_x = tf.split(
      value=input_tensor, num_or_size_splits=2, axis=3)
  target_y *= scale_y
  target_x *= scale_x

  # Resize the *rescaled* offsets. Previously the concatenated result was
  # discarded and the unscaled input was resized, so the returned offsets were
  # never rescaled.
  rescaled_offsets = tf.concat([target_y, target_x], 3)
  return tf.image.resize(
      rescaled_offsets,
      size=target_size,
      method=tf.image.ResizeMethod.BILINEAR)