Unverified Commit edee549f authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added `_parse_data` with `is_training` flag

parent 7e4f1ef3
......@@ -149,6 +149,7 @@ class Parser(parser.Parser):
self._groundtruth_padded_size[0],
self._groundtruth_padded_size[1])
mask -= 1
# Assign ignore label to the padded region.
mask = tf.where(
tf.equal(mask, -1),
......@@ -157,8 +158,7 @@ class Parser(parser.Parser):
mask = tf.squeeze(mask, axis=0)
return mask
def _parse_train_data(self, data):
"""Parses data for training."""
def _parse_data(self, data, is_training):
image = data['image']
image = preprocess_ops.normalize_image(image)
......@@ -170,7 +170,7 @@ class Parser(parser.Parser):
dtype=tf.float32)
# Flips image randomly during training.
if self._aug_rand_hflip:
if self._aug_rand_hflip and is_training:
masks = tf.stack([category_mask, instance_mask], axis=0)
image, _, masks = preprocess_ops.random_horizontal_flip(
image=image, masks=masks)
......@@ -182,22 +182,22 @@ class Parser(parser.Parser):
image,
self._output_size,
self._output_size,
aug_scale_min=self._aug_scale_min,
aug_scale_max=self._aug_scale_max)
aug_scale_min=self._aug_scale_min if is_training else 1.0,
aug_scale_max=self._aug_scale_max if is_training else 1.0)
category_mask = self._resize_and_crop_mask(
category_mask,
image_info,
is_training=True)
is_training=is_training)
instance_mask = self._resize_and_crop_mask(
instance_mask,
image_info,
is_training=True)
is_training=is_training)
centers_heatmap, centers_offset = self._encode_centers_and_offets(
instance_mask=instance_mask[:, :, 0])
# Cast image as self._dtype
# Cast image and labels as self._dtype
image = tf.cast(image, dtype=self._dtype)
category_mask = tf.cast(category_mask, dtype=self._dtype)
instance_mask = tf.cast(instance_mask, dtype=self._dtype)
......@@ -216,57 +216,13 @@ class Parser(parser.Parser):
}
return image, labels
def _parse_train_data(self, data):
"""Parses data for training."""
return self._parse_data(data=data, is_training=True)
def _parse_eval_data(self, data):
"""Parses data for evaluation."""
image = data['image']
image = preprocess_ops.normalize_image(image)
# shape of masks: [H, W]
category_mask = tf.cast(
data['groundtruth_panoptic_category_mask'][:, :, 0],
dtype=tf.float32)
instance_mask = tf.cast(
data['groundtruth_panoptic_instance_mask'][:, :, 0],
dtype=tf.float32)
# Resizes and crops image.
image, image_info = preprocess_ops.resize_and_crop_image(
image,
self._output_size,
self._output_size,
aug_scale_min=1.0,
aug_scale_max=1.0)
category_mask = self._resize_and_crop_mask(
category_mask,
image_info,
is_training=False)
instance_mask = self._resize_and_crop_mask(
instance_mask,
image_info,
is_training=False)
centers_heatmap, centers_offset = self._encode_centers_and_offets(
instance_mask=instance_mask[:, :, 0])
# Cast image as self._dtype
image = tf.cast(image, dtype=self._dtype)
category_mask = tf.cast(category_mask, dtype=self._dtype)
instance_mask = tf.cast(instance_mask, dtype=self._dtype)
centers_heatmap = tf.cast(centers_heatmap, dtype=self._dtype)
centers_offset = tf.cast(centers_offset, dtype=self._dtype)
things_mask = tf.cast(
tf.not_equal(instance_mask, self._ignore_label),
dtype=self._dtype)
labels = {
'category_mask': category_mask,
'instance_mask': instance_mask,
'centers_heatmap': centers_heatmap,
'centers_offset': centers_offset,
'things_mask': things_mask
}
return image, labels
return self._parse_data(data=data, is_training=False)
def _encode_centers_and_offets(self, instance_mask):
"""Generates center heatmaps and offets from instance id mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment