Commit 7007d9e3, authored by Zhichao Lu, committed by pkulzc
Browse files

Updating transform_input_data to resize original image.

Updating transform_input_data to resize original image. This is necessary for result_dict_for_single_example(), since it expects the input image and groundtruth masks to be of the same spatial dimensions.

PiperOrigin-RevId: 189786443
parent 6f1756bc
...@@ -58,7 +58,8 @@ def transform_input_data(tensor_dict, ...@@ -58,7 +58,8 @@ def transform_input_data(tensor_dict,
Data transformation functions are applied in the following order. Data transformation functions are applied in the following order.
1. data_augmentation_fn (optional): applied on tensor_dict. 1. data_augmentation_fn (optional): applied on tensor_dict.
2. model_preprocess_fn: applied only on image tensor in tensor_dict. 2. model_preprocess_fn: applied only on image tensor in tensor_dict.
3. image_resizer_fn: applied only on instance mask tensor in tensor_dict. 3. image_resizer_fn: applied on original image and instance mask tensor in
tensor_dict.
4. one_hot_encoding: applied to classes tensor in tensor_dict. 4. one_hot_encoding: applied to classes tensor in tensor_dict.
5. merge_multiple_boxes (optional): when groundtruth boxes are exactly the 5. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
same they can be merged into a single box with an associated k-hot class same they can be merged into a single box with an associated k-hot class
...@@ -70,10 +71,11 @@ def transform_input_data(tensor_dict, ...@@ -70,10 +71,11 @@ def transform_input_data(tensor_dict,
model_preprocess_fn: model's preprocess function to apply on image tensor. model_preprocess_fn: model's preprocess function to apply on image tensor.
This function must take in a 4-D float tensor and return a 4-D preprocess This function must take in a 4-D float tensor and return a 4-D preprocess
float tensor and a tensor containing the true image shape. float tensor and a tensor containing the true image shape.
image_resizer_fn: image resizer function to apply on groundtruth instance image_resizer_fn: image resizer function to apply on original image (if
masks. This function must take a 4-D float tensor of image and a 4-D `retain_original_image` is True) and groundtruth instance masks. This
tensor of instances masks and return resized version of these along with function must take a 3-D float tensor of an image and a 3-D tensor of
the true shapes. instance masks and return a resized version of these along with the true
shapes.
num_classes: number of max classes to one-hot (or k-hot) encode the class num_classes: number of max classes to one-hot (or k-hot) encode the class
labels. labels.
data_augmentation_fn: (optional) data augmentation function to apply on data_augmentation_fn: (optional) data augmentation function to apply on
...@@ -88,17 +90,19 @@ def transform_input_data(tensor_dict, ...@@ -88,17 +90,19 @@ def transform_input_data(tensor_dict,
after applying all the transformations. after applying all the transformations.
""" """
if retain_original_image: if retain_original_image:
tensor_dict[fields.InputDataFields. original_image_resized, _ = image_resizer_fn(
original_image] = tensor_dict[fields.InputDataFields.image] tensor_dict[fields.InputDataFields.image])
tensor_dict[fields.InputDataFields.original_image] = tf.cast(
original_image_resized, tf.uint8)
# Apply data augmentation ops. # Apply data augmentation ops.
if data_augmentation_fn is not None: if data_augmentation_fn is not None:
tensor_dict = data_augmentation_fn(tensor_dict) tensor_dict = data_augmentation_fn(tensor_dict)
# Apply model preprocessing ops and resize instance masks. # Apply model preprocessing ops and resize instance masks.
image = tf.expand_dims( image = tensor_dict[fields.InputDataFields.image]
tf.to_float(tensor_dict[fields.InputDataFields.image]), axis=0) preprocessed_resized_image, true_image_shape = model_preprocess_fn(
preprocessed_resized_image, true_image_shape = model_preprocess_fn(image) tf.expand_dims(tf.to_float(image), axis=0))
tensor_dict[fields.InputDataFields.image] = tf.squeeze( tensor_dict[fields.InputDataFields.image] = tf.squeeze(
preprocessed_resized_image, axis=0) preprocessed_resized_image, axis=0)
tensor_dict[fields.InputDataFields.true_image_shape] = tf.squeeze( tensor_dict[fields.InputDataFields.true_image_shape] = tf.squeeze(
......
...@@ -462,22 +462,31 @@ class DataTransformationFnTest(tf.test.TestCase): ...@@ -462,22 +462,31 @@ class DataTransformationFnTest(tf.test.TestCase):
fields.InputDataFields.groundtruth_classes: fields.InputDataFields.groundtruth_classes:
tf.constant(np.array([3, 1], np.int32)) tf.constant(np.array([3, 1], np.int32))
} }
def fake_image_resizer_fn(image, masks): def fake_image_resizer_fn(image, masks=None):
resized_image = tf.image.resize_images(image, [8, 8]) resized_image = tf.image.resize_images(image, [8, 8])
resized_masks = tf.transpose( results = [resized_image]
tf.image.resize_images(tf.transpose(masks, [1, 2, 0]), [8, 8]), if masks is not None:
[2, 0, 1]) resized_masks = tf.transpose(
return resized_image, resized_masks, tf.shape(resized_image) tf.image.resize_images(tf.transpose(masks, [1, 2, 0]), [8, 8]),
[2, 0, 1])
results.append(resized_masks)
results.append(tf.shape(resized_image))
return results
num_classes = 3 num_classes = 3
input_transformation_fn = functools.partial( input_transformation_fn = functools.partial(
inputs.transform_input_data, inputs.transform_input_data,
model_preprocess_fn=_fake_model_preprocessor_fn, model_preprocess_fn=_fake_model_preprocessor_fn,
image_resizer_fn=fake_image_resizer_fn, image_resizer_fn=fake_image_resizer_fn,
num_classes=num_classes) num_classes=num_classes,
retain_original_image=True)
with self.test_session() as sess: with self.test_session() as sess:
transformed_inputs = sess.run( transformed_inputs = sess.run(
input_transformation_fn(tensor_dict=tensor_dict)) input_transformation_fn(tensor_dict=tensor_dict))
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.original_image].dtype, tf.uint8)
self.assertAllEqual(transformed_inputs[
fields.InputDataFields.original_image].shape, [8, 8, 3])
self.assertAllEqual(transformed_inputs[ self.assertAllEqual(transformed_inputs[
fields.InputDataFields.groundtruth_instance_masks].shape, [2, 8, 8]) fields.InputDataFields.groundtruth_instance_masks].shape, [2, 8, 8])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment