Unverified Commit 797d3a5b authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

paste masks in float32 precision

parent 9ab3bf77
...@@ -110,6 +110,9 @@ class PasteMasks(tf.keras.layers.Layer): ...@@ -110,6 +110,9 @@ class PasteMasks(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
masks, boxes = inputs masks, boxes = inputs
masks_dtype = masks.dtype
masks = tf.cast(masks, dtype=tf.float32)
boxes = tf.cast(boxes, dtype=tf.float32)
y0, x0, y1, x1 = tf.split(boxes, 4, axis=1) y0, x0, y1, x1 = tf.split(boxes, 4, axis=1)
x_coords = tf.range(0, self._output_size[1], dtype=boxes.dtype) x_coords = tf.range(0, self._output_size[1], dtype=boxes.dtype)
...@@ -124,7 +127,7 @@ class PasteMasks(tf.keras.layers.Layer): ...@@ -124,7 +127,7 @@ class PasteMasks(tf.keras.layers.Layer):
tf.expand_dims(y_coords, axis=2), tf.expand_dims(y_coords, axis=2),
multiples=[1, 1, self._output_size[1]]) multiples=[1, 1, self._output_size[1]])
pasted_masks = self._grid_sampler((masks, x_coords, y_coords)) pasted_masks = self._grid_sampler((masks, x_coords, y_coords))
return pasted_masks return tf.cast(pasted_masks, dtype=masks_dtype)
def get_config(self): def get_config(self):
return self._config_dict return self._config_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment