Commit 532b946c authored by Vighnesh Birodkar, committed by TF Object Detection Team
Browse files

Add image+box module in exporter.

PiperOrigin-RevId: 363794263
parent dfaf525e
...@@ -75,6 +75,13 @@ class FakeModel(model.DetectionModel): ...@@ -75,6 +75,13 @@ class FakeModel(model.DetectionModel):
} }
return postprocessed_tensors return postprocessed_tensors
def predict_masks_from_boxes(self, prediction_dict, true_image_shapes, boxes):
  """Returns the postprocessed detections plus a constant fake mask tensor.

  Args:
    prediction_dict: predictions, forwarded unchanged to `postprocess`.
    true_image_shapes: image shapes, forwarded unchanged to `postprocess`.
    boxes: accepted to match the expected signature; not read by this fake.

  Returns:
    The dict produced by `postprocess`, with 'detection_masks' set to a
    float32 tensor of ones of shape (1, 2, 16).
  """
  detections = self.postprocess(prediction_dict, true_image_shapes)
  fake_masks = tf.ones(shape=(1, 2, 16), dtype=tf.float32)
  detections['detection_masks'] = fake_masks
  return detections
def restore_map(self, checkpoint_path, fine_tune_checkpoint_type):
  """No-op: accepts the arguments and returns None.

  Args:
    checkpoint_path: unused.
    fine_tune_checkpoint_type: unused.
  """
  pass
...@@ -291,6 +298,83 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): ...@@ -291,6 +298,83 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
[[150 + 0.7, 150 + 0.6], [150 + 0.9, 150 + 0.0]]) [[150 + 0.7, 150 + 0.6], [150 + 0.9, 150 + 0.0]])
class DetectionFromImageAndBoxModuleTest(tf.test.TestCase):
  """Tests exporting and running the 'image_and_boxes_tensor' saved model."""

  def get_dummy_input(self, input_type):
    """Get dummy input for the given input type.

    Args:
      input_type: one of 'image_tensor', 'image_and_boxes_tensor',
        'float_image_tensor', 'encoded_image_string_tensor' or 'tf_example'.

    Returns:
      A zero-filled numpy array of shape (1, 20, 20, 3) for the tensor input
      types, or a single-element list holding the encoded PNG bytes /
      serialized tf.Example for the string input types.

    Raises:
      ValueError: if `input_type` is not one of the supported values.
    """
    if input_type in ('image_tensor', 'image_and_boxes_tensor'):
      return np.zeros((1, 20, 20, 3), dtype=np.uint8)
    if input_type == 'float_image_tensor':
      return np.zeros((1, 20, 20, 3), dtype=np.float32)
    if input_type == 'encoded_image_string_tensor':
      image = Image.new('RGB', (20, 20))
      byte_io = io.BytesIO()
      image.save(byte_io, 'PNG')
      return [byte_io.getvalue()]
    if input_type == 'tf_example':
      image_tensor = tf.zeros((20, 20, 3), dtype=tf.uint8)
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      dataset_util.bytes_feature(encoded_jpeg),
                  'image/format':
                      dataset_util.bytes_feature(six.b('jpeg')),
                  'image/source_id':
                      dataset_util.bytes_feature(six.b('image_id')),
              })).SerializeToString()
      return [example]
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing error at the call site.
    raise ValueError('Unsupported input type: {}'.format(input_type))

  def _save_checkpoint_from_mock_model(self,
                                       checkpoint_dir,
                                       conv_weight_scalar=6.0):
    """Runs a FakeModel once and saves a checkpoint of it.

    Args:
      checkpoint_dir: directory the checkpoint manager writes into.
      conv_weight_scalar: value passed to the FakeModel constructor.
    """
    mock_model = FakeModel(conv_weight_scalar)
    fake_image = tf.zeros(shape=[1, 10, 10, 3], dtype=tf.float32)
    # Run the full preprocess/predict/postprocess path once before saving.
    preprocessed_inputs, true_image_shapes = mock_model.preprocess(fake_image)
    predictions = mock_model.predict(preprocessed_inputs, true_image_shapes)
    mock_model.postprocess(predictions, true_image_shapes)

    ckpt = tf.train.Checkpoint(model=mock_model)
    exported_checkpoint_manager = tf.train.CheckpointManager(
        ckpt, checkpoint_dir, max_to_keep=1)
    exported_checkpoint_manager.save(checkpoint_number=0)

  def test_export_saved_model_and_run_inference_for_segmentation(
      self, input_type='image_and_boxes_tensor'):
    """Exports with a mocked model builder and checks the mask output shape."""
    tmp_dir = self.get_temp_dir()
    self._save_checkpoint_from_mock_model(tmp_dir)
    output_directory = os.path.join(tmp_dir, 'output')
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()

    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel()
      # patch.dict restores the module-level map when the block exits, so the
      # mock builder does not leak into tests that run afterwards.
      with mock.patch.dict(exporter_lib_v2.INPUT_BUILDER_UTIL_MAP,
                           {'model_build': mock_builder}):
        exporter_lib_v2.export_inference_graph(
            input_type=input_type,
            pipeline_config=pipeline_config,
            trained_checkpoint_dir=tmp_dir,
            output_directory=output_directory)

    saved_model_path = os.path.join(output_directory, 'saved_model')
    detect_fn = tf.saved_model.load(saved_model_path)
    image = self.get_dummy_input(input_type)
    # One image with two boxes.
    boxes = tf.constant([
        [
            [0.0, 0.0, 0.5, 0.5],
            [0.5, 0.5, 0.8, 0.8],
        ],
    ])
    detections = detect_fn(tf.constant(image), boxes)

    detection_fields = fields.DetectionResultFields
    self.assertIn(detection_fields.detection_masks, detections)
    self.assertListEqual(
        list(detections[detection_fields.detection_masks].shape), [1, 2, 16])
if __name__ == '__main__':
  # Switch on TF2 behavior before handing control to the test runner.
  tf.enable_v2_behavior()
  tf.test.main()
...@@ -288,3 +288,73 @@ def export_inference_graph(input_type, ...@@ -288,3 +288,73 @@ def export_inference_graph(input_type,
signatures=concrete_function) signatures=concrete_function)
config_util.save_pipeline_config(pipeline_config, output_directory) config_util.save_pipeline_config(pipeline_config, output_directory)
class DetectionFromImageAndBoxModule(DetectionInferenceModule):
  """Detection Inference Module for image with bounding box inputs.

  The saved model will require two inputs (image and normalized boxes) and run
  per-box mask prediction. To be compatible with this exporter, the detection
  model has to implement a method called predict_masks_from_boxes(
    prediction_dict, true_image_shapes, provided_boxes, **params), where
    - prediction_dict is a dict returned by the predict method.
    - true_image_shapes is a tensor of size [batch_size, 3], containing the
      true shape of each image in case it is padded.
    - provided_boxes is a [batch_size, num_boxes, 4] size tensor containing
      boxes specified in normalized coordinates.
  """

  def __init__(self,
               detection_model,
               use_side_inputs=False,
               zipped_side_inputs=None):
    """Initializes a module for detection.

    Args:
      detection_model: the detection model to use for inference. Must expose
        a `predict_masks_from_boxes` attribute.
      use_side_inputs: whether to use side inputs.
      zipped_side_inputs: the zipped side inputs.

    Raises:
      ValueError: if `detection_model` has no `predict_masks_from_boxes`
        attribute.
    """
    # Explicit check rather than `assert`, so the validation still runs when
    # Python is started with -O (which strips assert statements).
    if not hasattr(detection_model, 'predict_masks_from_boxes'):
      raise ValueError(
          'detection_model must implement predict_masks_from_boxes to be '
          'exported with the image_and_boxes_tensor input type.')
    super(DetectionFromImageAndBoxModule,
          self).__init__(detection_model, use_side_inputs, zipped_side_inputs)

  def _run_segmentation_on_images(self, image, boxes, **kwargs):
    """Run segmentation on images with provided boxes.

    Args:
      image: uint8 Tensor of shape [1, None, None, 3].
      boxes: float32 tensor of shape [1, None, 4] containing normalized box
        coordinates.
      **kwargs: additional keyword arguments, forwarded to the model's
        `predict` call.

    Returns:
      Tensor dictionary holding detections (including masks). Every value is
      cast to float32, and the detection classes are shifted by one.
    """
    # NOTE(review): presumably shifts the model's class ids to the 1-based
    # ids used in exported outputs -- confirm against the other modules.
    label_id_offset = 1

    image = tf.cast(image, tf.float32)
    image, shapes = self._model.preprocess(image)
    prediction_dict = self._model.predict(image, shapes, **kwargs)
    detections = self._model.predict_masks_from_boxes(prediction_dict, shapes,
                                                      boxes)

    classes_field = fields.DetectionResultFields.detection_classes
    detections[classes_field] = (
        tf.cast(detections[classes_field], tf.float32) + label_id_offset)

    # Give every output a uniform float32 dtype. Only existing keys are
    # reassigned, so the dict does not change size during iteration.
    for key, val in detections.items():
      detections[key] = tf.cast(val, tf.float32)

    return detections

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8),
      tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32)
  ])
  def __call__(self, input_tensor, boxes):
    """Runs per-box segmentation on a single image.

    Args:
      input_tensor: uint8 image tensor of shape [1, None, None, 3].
      boxes: float32 tensor of shape [1, None, 4].

    Returns:
      The tensor dictionary produced by `_run_segmentation_on_images`.
    """
    return self._run_segmentation_on_images(input_tensor, boxes)
# Register the image+boxes module under its exporter input type.
DETECTION_MODULE_MAP['image_and_boxes_tensor'] = DetectionFromImageAndBoxModule
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment