Commit c67aad59 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support None num_boxes and refactor serving.

PiperOrigin-RevId: 400773326
parent c68dbef0
...@@ -151,7 +151,7 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -151,7 +151,7 @@ class MaskRCNNModel(tf.keras.Model):
model_mask_outputs = self._call_mask_outputs( model_mask_outputs = self._call_mask_outputs(
model_box_outputs=model_outputs, model_box_outputs=model_outputs,
features=intermediate_outputs['features'], features=model_outputs['decoder_features'],
current_rois=intermediate_outputs['current_rois'], current_rois=intermediate_outputs['current_rois'],
matched_gt_indices=intermediate_outputs['matched_gt_indices'], matched_gt_indices=intermediate_outputs['matched_gt_indices'],
matched_gt_boxes=intermediate_outputs['matched_gt_boxes'], matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
...@@ -161,6 +161,15 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -161,6 +161,15 @@ class MaskRCNNModel(tf.keras.Model):
model_outputs.update(model_mask_outputs) model_outputs.update(model_mask_outputs)
return model_outputs return model_outputs
def _get_backbone_and_decoder_features(self, images):
backbone_features = self.backbone(images)
if self.decoder:
features = self.decoder(backbone_features)
else:
features = backbone_features
return backbone_features, features
def _call_box_outputs( def _call_box_outputs(
self, images: tf.Tensor, self, images: tf.Tensor,
image_shape: tf.Tensor, image_shape: tf.Tensor,
...@@ -173,18 +182,15 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -173,18 +182,15 @@ class MaskRCNNModel(tf.keras.Model):
model_outputs = {} model_outputs = {}
# Feature extraction. # Feature extraction.
backbone_features = self.backbone(images) (backbone_features,
if self.decoder: decoder_features) = self._get_backbone_and_decoder_features(images)
features = self.decoder(backbone_features)
else:
features = backbone_features
# Region proposal network. # Region proposal network.
rpn_scores, rpn_boxes = self.rpn_head(features) rpn_scores, rpn_boxes = self.rpn_head(decoder_features)
model_outputs.update({ model_outputs.update({
'backbone_features': backbone_features, 'backbone_features': backbone_features,
'decoder_features': features, 'decoder_features': decoder_features,
'rpn_boxes': rpn_boxes, 'rpn_boxes': rpn_boxes,
'rpn_scores': rpn_scores 'rpn_scores': rpn_scores
}) })
...@@ -219,7 +225,7 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -219,7 +225,7 @@ class MaskRCNNModel(tf.keras.Model):
(class_outputs, box_outputs, model_outputs, matched_gt_boxes, (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices, matched_gt_classes, matched_gt_indices,
current_rois) = self._run_frcnn_head( current_rois) = self._run_frcnn_head(
features=features, features=decoder_features,
rois=current_rois, rois=current_rois,
gt_boxes=gt_boxes, gt_boxes=gt_boxes,
gt_classes=gt_classes, gt_classes=gt_classes,
...@@ -270,7 +276,6 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -270,7 +276,6 @@ class MaskRCNNModel(tf.keras.Model):
'matched_gt_boxes': matched_gt_boxes, 'matched_gt_boxes': matched_gt_boxes,
'matched_gt_indices': matched_gt_indices, 'matched_gt_indices': matched_gt_indices,
'matched_gt_classes': matched_gt_classes, 'matched_gt_classes': matched_gt_classes,
'features': features,
'current_rois': current_rois, 'current_rois': current_rois,
} }
return (model_outputs, intermediate_outputs) return (model_outputs, intermediate_outputs)
...@@ -302,19 +307,16 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -302,19 +307,16 @@ class MaskRCNNModel(tf.keras.Model):
current_rois = model_outputs['detection_boxes'] current_rois = model_outputs['detection_boxes']
roi_classes = model_outputs['detection_classes'] roi_classes = model_outputs['detection_classes']
# Mask RoI align. mask_logits, mask_probs = self._features_to_mask_outputs(
mask_roi_features = self.mask_roi_aligner(features, current_rois) features, current_rois, roi_classes)
# Mask head.
raw_masks = self.mask_head([mask_roi_features, roi_classes])
if training: if training:
model_outputs.update({ model_outputs.update({
'mask_outputs': raw_masks, 'mask_outputs': mask_logits,
}) })
else: else:
model_outputs.update({ model_outputs.update({
'detection_masks': tf.math.sigmoid(raw_masks), 'detection_masks': mask_probs,
}) })
return model_outputs return model_outputs
...@@ -395,6 +397,15 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -395,6 +397,15 @@ class MaskRCNNModel(tf.keras.Model):
return (class_outputs, box_outputs, model_outputs, matched_gt_boxes, return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices, rois) matched_gt_classes, matched_gt_indices, rois)
def _features_to_mask_outputs(self, features, rois, roi_classes):
  """Produces mask logits and probabilities for the given RoIs.

  Args:
    features: Feature map(s) passed to the mask RoI aligner.
    rois: Region-of-interest boxes to crop features for.
    roi_classes: Per-RoI class indices fed to the mask head.

  Returns:
    A `(mask_logits, mask_probs)` tuple, where `mask_probs` is the
    elementwise sigmoid of the logits.
  """
  # Crop per-RoI features, then predict raw (logit) masks from them.
  aligned_features = self.mask_roi_aligner(features, rois)
  mask_logits = self.mask_head([aligned_features, roi_classes])
  mask_probs = tf.nn.sigmoid(mask_logits)
  return mask_logits, mask_probs
@property @property
def checkpoint_items( def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]: self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
......
...@@ -43,10 +43,11 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x): ...@@ -43,10 +43,11 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
[batch_size, num_boxes, output_size, output_size, num_filters]. [batch_size, num_boxes, output_size, output_size, num_filters].
""" """
batch_size, num_boxes, output_size, _, num_filters = ( features_shape = tf.shape(features)
features.get_shape().as_list()) batch_size, num_boxes, output_size, num_filters = (
if batch_size is None: features_shape[0], features_shape[1], features_shape[2],
batch_size = tf.shape(features)[0] features_shape[4])
output_size = output_size // 2 output_size = output_size // 2
kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1]) kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1])
kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2]) kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2])
...@@ -88,7 +89,8 @@ def _compute_grid_positions(boxes, boundaries, output_size, sample_offset): ...@@ -88,7 +89,8 @@ def _compute_grid_positions(boxes, boundaries, output_size, sample_offset):
box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2] box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2]
box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2] box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2]
""" """
batch_size, num_boxes, _ = boxes.get_shape().as_list() boxes_shape = tf.shape(boxes)
batch_size, num_boxes = boxes_shape[0], boxes_shape[1]
if batch_size is None: if batch_size is None:
batch_size = tf.shape(boxes)[0] batch_size = tf.shape(boxes)[0]
box_grid_x = [] box_grid_x = []
...@@ -161,11 +163,12 @@ def multilevel_crop_and_resize(features, ...@@ -161,11 +163,12 @@ def multilevel_crop_and_resize(features,
levels = list(features.keys()) levels = list(features.keys())
min_level = int(min(levels)) min_level = int(min(levels))
max_level = int(max(levels)) max_level = int(max(levels))
features_shape = tf.shape(features[str(min_level)])
batch_size, max_feature_height, max_feature_width, num_filters = ( batch_size, max_feature_height, max_feature_width, num_filters = (
features[str(min_level)].get_shape().as_list()) features_shape[0], features_shape[1], features_shape[2],
if batch_size is None: features_shape[3])
batch_size = tf.shape(features[str(min_level)])[0]
_, num_boxes, _ = boxes.get_shape().as_list() num_boxes = tf.shape(boxes)[1]
# Stack feature pyramid into a features_all of shape # Stack feature pyramid into a features_all of shape
# [batch_size, levels, height, width, num_filters]. # [batch_size, levels, height, width, num_filters].
......
...@@ -131,7 +131,7 @@ class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -131,7 +131,7 @@ class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
model_mask_outputs = self._call_mask_outputs( model_mask_outputs = self._call_mask_outputs(
model_box_outputs=model_outputs, model_box_outputs=model_outputs,
features=intermediate_outputs['features'], features=model_outputs['decoder_features'],
current_rois=intermediate_outputs['current_rois'], current_rois=intermediate_outputs['current_rois'],
matched_gt_indices=intermediate_outputs['matched_gt_indices'], matched_gt_indices=intermediate_outputs['matched_gt_indices'],
matched_gt_boxes=intermediate_outputs['matched_gt_boxes'], matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# Lint as: python3 # Lint as: python3
"""Detection input and model functions for serving/inference.""" """Detection input and model functions for serving/inference."""
from typing import Mapping, Text
import tensorflow as tf import tensorflow as tf
from official.vision.beta import configs from official.vision.beta import configs
...@@ -78,13 +79,17 @@ class DetectionModule(export_base.ExportModule): ...@@ -78,13 +79,17 @@ class DetectionModule(export_base.ExportModule):
return image, anchor_boxes, image_info return image, anchor_boxes, image_info
def serve(self, images: tf.Tensor): def preprocess(self, images: tf.Tensor) -> (
"""Cast image to float and run inference. tf.Tensor, Mapping[Text, tf.Tensor], tf.Tensor):
"""Preprocess inputs to be suitable for the model.
Args: Args:
images: uint8 Tensor of shape [batch_size, None, None, 3] images: The images tensor.
Returns: Returns:
Tensor holding detection output logits. images: The images tensor cast to float.
anchor_boxes: Dict mapping anchor levels to anchor boxes.
image_info: Tensor containing the details of the image resizing.
""" """
model_params = self.params.task.model model_params = self.params.task.model
with tf.device('cpu:0'): with tf.device('cpu:0'):
...@@ -117,6 +122,18 @@ class DetectionModule(export_base.ExportModule): ...@@ -117,6 +122,18 @@ class DetectionModule(export_base.ExportModule):
image_info_spec), image_info_spec),
parallel_iterations=32)) parallel_iterations=32))
return images, anchor_boxes, image_info
def serve(self, images: tf.Tensor):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding detection output logits.
"""
images, anchor_boxes, image_info = self.preprocess(images)
input_image_shape = image_info[:, 1, :] input_image_shape = image_info[:, 1, :]
# To overcome keras.Model extra limitation to save a model with layers that # To overcome keras.Model extra limitation to save a model with layers that
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment