Commit 0d6ce602 authored by Vivek Rathod, committed by TF Object Detection Team
Browse files

Allow batch size > 1 with exporter_main_v2.py

PiperOrigin-RevId: 354429469
parent c02bca43
...@@ -56,8 +56,7 @@ class FakeModel(model.DetectionModel): ...@@ -56,8 +56,7 @@ class FakeModel(model.DetectionModel):
value=conv_weight_scalar)) value=conv_weight_scalar))
def preprocess(self, inputs):
  """Fake preprocessing: pass images through unchanged.

  Returns the identity of `inputs` together with the true shapes derived
  from the input tensor itself.
  """
  preprocessed = tf.identity(inputs)
  return preprocessed, exporter_lib_v2.get_true_shapes(inputs)
def predict(self, preprocessed_inputs, true_image_shapes):
  """Fake prediction: apply the conv layer; `true_image_shapes` is unused."""
  del true_image_shapes  # Unused by the fake model.
  return {'image': self._conv(preprocessed_inputs)}
......
...@@ -54,8 +54,7 @@ class FakeModel(model.DetectionModel): ...@@ -54,8 +54,7 @@ class FakeModel(model.DetectionModel):
value=conv_weight_scalar)) value=conv_weight_scalar))
def preprocess(self, inputs):
  """Fake preprocessing: pass images through unchanged.

  Returns the identity of `inputs` together with the true shapes derived
  from the input tensor itself.
  """
  preprocessed = tf.identity(inputs)
  return preprocessed, exporter_lib_v2.get_true_shapes(inputs)
def predict(self, preprocessed_inputs, true_image_shapes):
  """Fake prediction: apply the conv layer; `true_image_shapes` is unused."""
  del true_image_shapes  # Unused by the fake model.
  return {'image': self._conv(preprocessed_inputs)}
......
...@@ -51,8 +51,7 @@ class FakeModel(model.DetectionModel): ...@@ -51,8 +51,7 @@ class FakeModel(model.DetectionModel):
value=conv_weight_scalar)) value=conv_weight_scalar))
def preprocess(self, inputs):
  """Fake preprocessing: pass images through unchanged.

  Returns the identity of `inputs` together with the true shapes derived
  from the input tensor itself.
  """
  preprocessed = tf.identity(inputs)
  return preprocessed, exporter_lib_v2.get_true_shapes(inputs)
def predict(self, preprocessed_inputs, true_image_shapes, **side_inputs): def predict(self, preprocessed_inputs, true_image_shapes, **side_inputs):
return_dict = {'image': self._conv(preprocessed_inputs)} return_dict = {'image': self._conv(preprocessed_inputs)}
......
...@@ -94,22 +94,37 @@ class DetectionInferenceModule(tf.Module): ...@@ -94,22 +94,37 @@ class DetectionInferenceModule(tf.Module):
def _get_side_names_from_zip(self, zipped_side_inputs): def _get_side_names_from_zip(self, zipped_side_inputs):
return [side[2] for side in zipped_side_inputs] return [side[2] for side in zipped_side_inputs]
def _preprocess_input(self, batch_input, decode_fn):
  """Decodes and preprocesses each element of a batched input.

  Input preprocessing happens on the CPU; device placement is handled
  automatically by TF, so no explicit device scope is needed.

  Args:
    batch_input: batched input tensor; each element must be decodable by
      `decode_fn` into a single image tensor (assumed rank-3 — the code
      indexes it as `[tf.newaxis, :, :, :]`).
    decode_fn: callable mapping one element of `batch_input` to an image
      tensor castable to float32.

  Returns:
    A `(images, true_shapes)` tuple: float32 preprocessed images and the
    int32 true shapes produced by the model's `preprocess`.
  """
  def _prep_single(element):
    # Decode one element, add a batch dimension for model.preprocess,
    # then strip it again so map_fn can re-batch the results.
    single_image = tf.cast(decode_fn(element), tf.float32)
    processed, true_shape = self._model.preprocess(
        single_image[tf.newaxis, :, :, :])
    return processed[0], true_shape[0]

  return tf.map_fn(
      _prep_single,
      elems=batch_input,
      parallel_iterations=32,
      back_prop=False,
      fn_output_signature=(tf.float32, tf.int32))
def _run_inference_on_images(self, images, true_shapes, **kwargs):
"""Cast image to float and run inference. """Cast image to float and run inference.
Args: Args:
image: uint8 Tensor of shape [1, None, None, 3]. images: float32 Tensor of shape [None, None, None, 3].
true_shapes: int32 Tensor of form [batch, 3]
**kwargs: additional keyword arguments. **kwargs: additional keyword arguments.
Returns: Returns:
Tensor dictionary holding detections. Tensor dictionary holding detections.
""" """
label_id_offset = 1 label_id_offset = 1
prediction_dict = self._model.predict(images, true_shapes, **kwargs)
image = tf.cast(image, tf.float32) detections = self._model.postprocess(prediction_dict, true_shapes)
image, shapes = self._model.preprocess(image)
prediction_dict = self._model.predict(image, shapes, **kwargs)
detections = self._model.postprocess(prediction_dict, shapes)
classes_field = fields.DetectionResultFields.detection_classes classes_field = fields.DetectionResultFields.detection_classes
detections[classes_field] = ( detections[classes_field] = (
tf.cast(detections[classes_field], tf.float32) + label_id_offset) tf.cast(detections[classes_field], tf.float32) + label_id_offset)
...@@ -144,7 +159,8 @@ class DetectionFromImageModule(DetectionInferenceModule): ...@@ -144,7 +159,8 @@ class DetectionFromImageModule(DetectionInferenceModule):
def call_func(input_tensor, *side_inputs):
  # Pair each positional side input with its registered name so they can
  # be forwarded to the model as keyword arguments.
  side_kwargs = dict(zip(self._side_input_names, side_inputs))
  images, true_shapes = self._preprocess_input(input_tensor, lambda x: x)
  return self._run_inference_on_images(images, true_shapes, **side_kwargs)
self.__call__ = tf.function(call_func, input_signature=sig) self.__call__ = tf.function(call_func, input_signature=sig)
...@@ -154,44 +170,43 @@ class DetectionFromImageModule(DetectionInferenceModule): ...@@ -154,44 +170,43 @@ class DetectionFromImageModule(DetectionInferenceModule):
zipped_side_inputs) zipped_side_inputs)
def get_true_shapes(input_tensor):
  """Returns the per-image shape repeated once for every batch element.

  Each row of the result is the input's shape with the leading batch
  dimension dropped (i.e. the dynamic image shape, assumed identical for
  all elements of the batch).

  Args:
    input_tensor: a batched tensor.

  Returns:
    An int32 tensor of shape [batch, rank(input_tensor) - 1].
  """
  shape = tf.shape(input_tensor)
  # Tile the single per-image shape [1, rank-1] down the batch dimension.
  return tf.tile(tf.expand_dims(shape[1:], axis=0), [shape[0], 1])
class DetectionFromFloatImageModule(DetectionInferenceModule):
  """Detection Inference Module for float image inputs."""

  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)])
  def __call__(self, input_tensor):
    # Inputs are already decoded images, so the decode step is the identity.
    batch_images, batch_shapes = self._preprocess_input(
        input_tensor, lambda x: x)
    return self._run_inference_on_images(batch_images, batch_shapes)
class DetectionFromEncodedImageModule(DetectionInferenceModule):
  """Detection Inference Module for encoded image string inputs."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def __call__(self, input_tensor):
    # Each string element is decoded into an image before preprocessing.
    batch_images, batch_shapes = self._preprocess_input(
        input_tensor, _decode_image)
    return self._run_inference_on_images(batch_images, batch_shapes)
class DetectionFromTFExampleModule(DetectionInferenceModule):
  """Detection Inference Module for TF.Example inputs."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def __call__(self, input_tensor):
    # Each serialized TF.Example is decoded into an image before
    # preprocessing.
    batch_images, batch_shapes = self._preprocess_input(
        input_tensor, _decode_tf_example)
    return self._run_inference_on_images(batch_images, batch_shapes)
DETECTION_MODULE_MAP = { DETECTION_MODULE_MAP = {
'image_tensor': DetectionFromImageModule, 'image_tensor': DetectionFromImageModule,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment