Unverified Commit 39678c06 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

precompute grid coordinates

parent 797d3a5b
......@@ -108,15 +108,16 @@ class PasteMasks(tf.keras.layers.Layer):
'grid_sampler': grid_sampler
}
def build(self, input_shape):
self._x_coords = tf.range(0, self._output_size[1], dtype=tf.float32)
self._y_coords = tf.range(0, self._output_size[0], dtype=tf.float32)
def call(self, inputs):
masks, boxes = inputs
masks_dtype = masks.dtype
masks = tf.cast(masks, dtype=tf.float32)
boxes = tf.cast(boxes, dtype=tf.float32)
y0, x0, y1, x1 = tf.split(boxes, 4, axis=1)
x_coords = tf.range(0, self._output_size[1], dtype=boxes.dtype)
y_coords = tf.range(0, self._output_size[0], dtype=boxes.dtype)
x_coords = tf.cast(self._x_coords, dtype=boxes.dtype)
y_coords = tf.cast(self._y_coords, dtype=boxes.dtype)
x_coords = (x_coords - x0) / (x1 - x0) * 2 - 1
y_coords = (y_coords - y0) / (y1 - y0) * 2 - 1
......@@ -127,7 +128,7 @@ class PasteMasks(tf.keras.layers.Layer):
tf.expand_dims(y_coords, axis=2),
multiples=[1, 1, self._output_size[1]])
pasted_masks = self._grid_sampler((masks, x_coords, y_coords))
return tf.cast(pasted_masks, dtype=masks_dtype)
return pasted_masks
def get_config(self):
return self._config_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment