Unverified Commit 051f1f0f authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

rename label keys

parent 4f8826ac
...@@ -194,25 +194,31 @@ class Parser(parser.Parser): ...@@ -194,25 +194,31 @@ class Parser(parser.Parser):
image_info, image_info,
is_training=is_training) is_training=is_training)
centers_heatmap, centers_offset = self._encode_centers_and_offets( instance_centers_heatmap, instance_centers_offset = self._encode_centers_and_offets(
instance_mask=instance_mask[:, :, 0]) instance_mask=instance_mask[:, :, 0])
# Cast image and labels as self._dtype # Cast image and labels as self._dtype
image = tf.cast(image, dtype=self._dtype) image = tf.cast(image, dtype=self._dtype)
category_mask = tf.cast(category_mask, dtype=self._dtype) category_mask = tf.cast(category_mask, dtype=self._dtype)
instance_mask = tf.cast(instance_mask, dtype=self._dtype) instance_mask = tf.cast(instance_mask, dtype=self._dtype)
centers_heatmap = tf.cast(centers_heatmap, dtype=self._dtype) instance_centers_heatmap = tf.cast(
centers_offset = tf.cast(centers_offset, dtype=self._dtype) instance_centers_heatmap, dtype=self._dtype)
things_mask = tf.cast( instance_centers_offset = tf.cast(
tf.not_equal(instance_mask, self._ignore_label), instance_centers_offset, dtype=self._dtype)
dtype=self._dtype)
valid_mask = tf.not_equal(
category_mask, self._ignore_label)
things_mask = tf.not_equal(
instance_mask, self._ignore_label)
labels = { labels = {
'category_mask': category_mask, 'category_mask': category_mask,
'instance_mask': instance_mask, 'instance_mask': instance_mask,
'centers_heatmap': centers_heatmap, 'instance_centers_heatmap': instance_centers_heatmap,
'centers_offset': centers_offset, 'instance_centers_offset': instance_centers_offset,
'things_mask': things_mask 'valid_mask': valid_mask,
'things_mask': things_mask,
'image_info': image_info
} }
return image, labels return image, labels
...@@ -231,8 +237,8 @@ class Parser(parser.Parser): ...@@ -231,8 +237,8 @@ class Parser(parser.Parser):
instance_mask: `tf.Tensor` of shape [height, width] representing instance_mask: `tf.Tensor` of shape [height, width] representing
groundtruth instance id mask. groundtruth instance id mask.
Returns: Returns:
centers_heatmap: `tf.Tensor` of shape [height, width, 1] instance_centers_heatmap: `tf.Tensor` of shape [height, width, 1]
centers_offset: `tf.Tensor` of shape [height, width, 2] instance_centers_offset: `tf.Tensor` of shape [height, width, 2]
""" """
shape = tf.shape(instance_mask) shape = tf.shape(instance_mask)
height, width = shape[0], shape[1] height, width = shape[0], shape[1]
...@@ -244,7 +250,7 @@ class Parser(parser.Parser): ...@@ -244,7 +250,7 @@ class Parser(parser.Parser):
# as size = int(6 * sigma + 3) # as size = int(6 * sigma + 3)
padding = padding_start + padding_end padding = padding_start + padding_end
centers_heatmap = tf.zeros( instance_centers_heatmap = tf.zeros(
shape=[height + padding, width + padding], shape=[height + padding, width + padding],
dtype=tf.float32) dtype=tf.float32)
centers_offset_y = tf.zeros( centers_offset_y = tf.zeros(
...@@ -278,8 +284,8 @@ class Parser(parser.Parser): ...@@ -278,8 +284,8 @@ class Parser(parser.Parser):
indices, shape=[2, gaussian_size * gaussian_size]) indices, shape=[2, gaussian_size * gaussian_size])
indices = tf.transpose(indices) indices = tf.transpose(indices)
centers_heatmap = tf.tensor_scatter_nd_max( instance_centers_heatmap = tf.tensor_scatter_nd_max(
tensor=centers_heatmap, tensor=instance_centers_heatmap,
indices=indices, indices=indices,
updates=self._gaussian) updates=self._gaussian)
...@@ -293,13 +299,13 @@ class Parser(parser.Parser): ...@@ -293,13 +299,13 @@ class Parser(parser.Parser):
indices=tf.cast(mask_indices, dtype=tf.int32), indices=tf.cast(mask_indices, dtype=tf.int32),
updates=tf.cast(mask_center_x, dtype=tf.float32) - mask_indices[:, 1]) updates=tf.cast(mask_center_x, dtype=tf.float32) - mask_indices[:, 1])
centers_heatmap = centers_heatmap[ instance_centers_heatmap = instance_centers_heatmap[
padding_start:padding_start + height, padding_start:padding_start + height,
padding_start:padding_start + width] padding_start:padding_start + width]
centers_heatmap = tf.expand_dims(centers_heatmap, axis=-1) instance_centers_heatmap = tf.expand_dims(instance_centers_heatmap, axis=-1)
centers_offset = tf.stack( instance_centers_offset = tf.stack(
[centers_offset_y, centers_offset_x], [centers_offset_y, centers_offset_x],
axis=-1) axis=-1)
return centers_heatmap, centers_offset return instance_centers_heatmap, instance_centers_offset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment