Commit 02ff7788 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower
Browse files

Allow Mask R-CNN to be run with mixed precision without producing NaNs.

Some parts of the forward pass previously overflowed in float16. Those parts are now computed in float32.

PiperOrigin-RevId: 380857663
parent f13895b9
...@@ -198,7 +198,8 @@ def multilevel_crop_and_resize(features, ...@@ -198,7 +198,8 @@ def multilevel_crop_and_resize(features,
# Assigns boxes to the right level. # Assigns boxes to the right level.
box_width = boxes[:, :, 3] - boxes[:, :, 1] box_width = boxes[:, :, 3] - boxes[:, :, 1]
box_height = boxes[:, :, 2] - boxes[:, :, 0] box_height = boxes[:, :, 2] - boxes[:, :, 0]
areas_sqrt = tf.cast(tf.sqrt(box_height * box_width), tf.float32) areas_sqrt = tf.sqrt(
tf.cast(box_height, tf.float32) * tf.cast(box_width, tf.float32))
levels = tf.cast( levels = tf.cast(
tf.math.floordiv( tf.math.floordiv(
tf.math.log(tf.divide(areas_sqrt, 224.0)), tf.math.log(tf.divide(areas_sqrt, 224.0)),
...@@ -456,6 +457,12 @@ def crop_mask_in_target_box(masks, ...@@ -456,6 +457,12 @@ def crop_mask_in_target_box(masks,
[batch_size, num_boxes, output_size, output_size]. [batch_size, num_boxes, output_size, output_size].
""" """
with tf.name_scope('crop_mask_in_target_box'): with tf.name_scope('crop_mask_in_target_box'):
# Cast to float32, as the y_transform and other transform variables may
# overflow in float16
masks = tf.cast(masks, tf.float32)
boxes = tf.cast(boxes, tf.float32)
target_boxes = tf.cast(target_boxes, tf.float32)
batch_size, num_masks, height, width = masks.get_shape().as_list() batch_size, num_masks, height, width = masks.get_shape().as_list()
if batch_size is None: if batch_size is None:
batch_size = tf.shape(masks)[0] batch_size = tf.shape(masks)[0]
......
...@@ -132,6 +132,9 @@ class IouSimilarity: ...@@ -132,6 +132,9 @@ class IouSimilarity:
Output shape: Output shape:
[M, N], or [B, M, N] [M, N], or [B, M, N]
""" """
boxes_1 = tf.cast(boxes_1, tf.float32)
boxes_2 = tf.cast(boxes_2, tf.float32)
boxes_1_rank = len(boxes_1.shape) boxes_1_rank = len(boxes_1.shape)
boxes_2_rank = len(boxes_2.shape) boxes_2_rank = len(boxes_2.shape)
if boxes_1_rank < 2 or boxes_1_rank > 3: if boxes_1_rank < 2 or boxes_1_rank > 3:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment