Commit 02d00c0c authored by A. Unique TensorFlower

Enable multiple inputs for tf.Vision RetinaNet model.

PiperOrigin-RevId: 465434866
parent d4efa810
@@ -13,7 +13,7 @@
 # limitations under the License.
 """RetinaNet."""
-from typing import Any, Mapping, List, Optional, Union
+from typing import Any, Mapping, List, Optional, Union, Sequence
 # Import libraries
 import tensorflow as tf
@@ -74,7 +74,7 @@ class RetinaNetModel(tf.keras.Model):
     self._detection_generator = detection_generator
   def call(self,
-           images: tf.Tensor,
+           images: Union[tf.Tensor, Sequence[tf.Tensor]],
            image_shape: Optional[tf.Tensor] = None,
            anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
            output_intermediate_features: bool = False,
@@ -82,8 +82,10 @@
     """Forward pass of the RetinaNet model.
     Args:
-      images: `Tensor`, the input batched images, whose shape is
-        [batch, height, width, 3].
+      images: `Tensor` or a sequence of `Tensor`, the input batched images to
+        the backbone network, whose shape(s) is [batch, height, width, 3]. If it
+        is a sequence of `Tensor`, we will assume the anchors are generated
+        based on the shape of the first image(s).
       image_shape: `Tensor`, the actual shape of the input images, whose shape
         is [batch, 2] where the last dimension is [height, width]. Note that
         this is the actual image shape excluding paddings. For example, images
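
As a rough usage sketch of the widened contract described above: the `retinanet` instance below is hypothetical, since building a real `RetinaNetModel` requires a backbone, decoder, head, and detection generator that this diff does not show, and the input sizes are illustrative only.

import tensorflow as tf

# Two batched inputs for the backbone; anchor boxes follow the first tensor's
# shape, per the docstring above.
images = [tf.zeros([2, 640, 640, 3]), tf.zeros([2, 320, 320, 3])]

# Actual (pre-padding) per-image sizes, shaped [batch, 2] as [height, width].
image_shape = tf.constant([[640, 512], [608, 640]], dtype=tf.int32)

# Hypothetical, already-built model instance (not constructed in this diff):
# outputs = retinanet(images, image_shape=image_shape, training=False)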
@@ -141,7 +143,16 @@
     else:
       # Generate anchor boxes for this batch if not provided.
       if anchor_boxes is None:
-        _, image_height, image_width, _ = images.get_shape().as_list()
+        if isinstance(images, Sequence):
+          primary_images = images[0]
+        elif isinstance(images, tf.Tensor):
+          primary_images = images
+        else:
+          raise ValueError(
+              'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
+              .format(type(images)))
+        _, image_height, image_width, _ = primary_images.get_shape().as_list()
         anchor_boxes = anchor.Anchor(
             min_level=self._config_dict['min_level'],
             max_level=self._config_dict['max_level'],
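
The hunk above routes either form of input to a single "primary" tensor whose static shape drives anchor generation. Below is a minimal, self-contained sketch of that dispatch in eager TensorFlow; `select_primary_images` is a hypothetical helper name for illustration, not something this commit adds.

from typing import Sequence, Union

import tensorflow as tf


def select_primary_images(
    images: Union[tf.Tensor, Sequence[tf.Tensor]]) -> tf.Tensor:
  """Returns the tensor whose shape should drive anchor generation."""
  if isinstance(images, Sequence):
    # A list/tuple of tensors: anchors follow the first entry.
    return images[0]
  if isinstance(images, tf.Tensor):
    # A single batched image tensor: the pre-existing behavior.
    return images
  raise ValueError(
      'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
      .format(type(images)))


primary = select_primary_images(
    [tf.zeros([2, 640, 640, 3]), tf.zeros([2, 320, 320, 3])])
_, image_height, image_width, _ = primary.get_shape().as_list()
print(image_height, image_width)  # -> 640 640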
@@ -152,7 +163,7 @@
         for l in anchor_boxes:
           anchor_boxes[l] = tf.tile(
               tf.expand_dims(anchor_boxes[l], axis=0),
-              [tf.shape(images)[0], 1, 1, 1])
+              [tf.shape(primary_images)[0], 1, 1, 1])
       # Post-processing.
       final_results = self.detection_generator(raw_boxes, raw_scores,
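
For context on the tiling in the hunk above, here is a tiny eager-mode sketch with toy anchor shapes; the level keys and the 9 * 4 channel count are illustrative assumptions, not values taken from this diff.

import tensorflow as tf

# Per-level anchor boxes are generated once without a batch dimension and then
# replicated to the batch size of the primary input tensor.
primary_images = tf.zeros([2, 640, 640, 3])      # assumed batch of 2
anchor_boxes = {
    '3': tf.zeros([80, 80, 9 * 4]),              # toy level-3 anchors
    '4': tf.zeros([40, 40, 9 * 4]),              # toy level-4 anchors
}

for l in anchor_boxes:
  anchor_boxes[l] = tf.tile(
      tf.expand_dims(anchor_boxes[l], axis=0),   # add a leading batch axis
      [tf.shape(primary_images)[0], 1, 1, 1])    # replicate per example

print(anchor_boxes['3'].shape)  # -> (2, 80, 80, 36)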