"vscode:/vscode.git/clone" did not exist on "172640ef140e6ecfd149c17734fda67bc472d38c"
Commit c67aad59 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Support None num_boxes and refactor serving.

PiperOrigin-RevId: 400773326
parent c68dbef0
......@@ -151,7 +151,7 @@ class MaskRCNNModel(tf.keras.Model):
model_mask_outputs = self._call_mask_outputs(
model_box_outputs=model_outputs,
features=intermediate_outputs['features'],
features=model_outputs['decoder_features'],
current_rois=intermediate_outputs['current_rois'],
matched_gt_indices=intermediate_outputs['matched_gt_indices'],
matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
......@@ -161,6 +161,15 @@ class MaskRCNNModel(tf.keras.Model):
model_outputs.update(model_mask_outputs)
return model_outputs
def _get_backbone_and_decoder_features(self, images):
backbone_features = self.backbone(images)
if self.decoder:
features = self.decoder(backbone_features)
else:
features = backbone_features
return backbone_features, features
def _call_box_outputs(
self, images: tf.Tensor,
image_shape: tf.Tensor,
......@@ -173,18 +182,15 @@ class MaskRCNNModel(tf.keras.Model):
model_outputs = {}
# Feature extraction.
backbone_features = self.backbone(images)
if self.decoder:
features = self.decoder(backbone_features)
else:
features = backbone_features
(backbone_features,
decoder_features) = self._get_backbone_and_decoder_features(images)
# Region proposal network.
rpn_scores, rpn_boxes = self.rpn_head(features)
rpn_scores, rpn_boxes = self.rpn_head(decoder_features)
model_outputs.update({
'backbone_features': backbone_features,
'decoder_features': features,
'decoder_features': decoder_features,
'rpn_boxes': rpn_boxes,
'rpn_scores': rpn_scores
})
......@@ -219,7 +225,7 @@ class MaskRCNNModel(tf.keras.Model):
(class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices,
current_rois) = self._run_frcnn_head(
features=features,
features=decoder_features,
rois=current_rois,
gt_boxes=gt_boxes,
gt_classes=gt_classes,
......@@ -270,7 +276,6 @@ class MaskRCNNModel(tf.keras.Model):
'matched_gt_boxes': matched_gt_boxes,
'matched_gt_indices': matched_gt_indices,
'matched_gt_classes': matched_gt_classes,
'features': features,
'current_rois': current_rois,
}
return (model_outputs, intermediate_outputs)
......@@ -302,19 +307,16 @@ class MaskRCNNModel(tf.keras.Model):
current_rois = model_outputs['detection_boxes']
roi_classes = model_outputs['detection_classes']
# Mask RoI align.
mask_roi_features = self.mask_roi_aligner(features, current_rois)
# Mask head.
raw_masks = self.mask_head([mask_roi_features, roi_classes])
mask_logits, mask_probs = self._features_to_mask_outputs(
features, current_rois, roi_classes)
if training:
model_outputs.update({
'mask_outputs': raw_masks,
'mask_outputs': mask_logits,
})
else:
model_outputs.update({
'detection_masks': tf.math.sigmoid(raw_masks),
'detection_masks': mask_probs,
})
return model_outputs
......@@ -395,6 +397,15 @@ class MaskRCNNModel(tf.keras.Model):
return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices, rois)
def _features_to_mask_outputs(self, features, rois, roi_classes):
  """Produces mask logits and probabilities for the given RoIs.

  Args:
    features: Feature map(s) consumed by the mask RoI aligner.
    rois: Region-of-interest boxes to crop mask features for.
    roi_classes: Per-RoI class indices passed to the mask head.

  Returns:
    A `(mask_logits, mask_probs)` tuple, where `mask_probs` is the
    element-wise sigmoid of `mask_logits`.
  """
  # Crop per-RoI mask features, then predict class-conditioned masks.
  roi_features = self.mask_roi_aligner(features, rois)
  mask_logits = self.mask_head([roi_features, roi_classes])
  return mask_logits, tf.nn.sigmoid(mask_logits)
@property
def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
......
......@@ -43,10 +43,11 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
[batch_size, num_boxes, output_size, output_size, num_filters].
"""
batch_size, num_boxes, output_size, _, num_filters = (
features.get_shape().as_list())
if batch_size is None:
batch_size = tf.shape(features)[0]
features_shape = tf.shape(features)
batch_size, num_boxes, output_size, num_filters = (
features_shape[0], features_shape[1], features_shape[2],
features_shape[4])
output_size = output_size // 2
kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1])
kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2])
......@@ -88,7 +89,8 @@ def _compute_grid_positions(boxes, boundaries, output_size, sample_offset):
box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2]
box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2]
"""
batch_size, num_boxes, _ = boxes.get_shape().as_list()
boxes_shape = tf.shape(boxes)
batch_size, num_boxes = boxes_shape[0], boxes_shape[1]
if batch_size is None:
batch_size = tf.shape(boxes)[0]
box_grid_x = []
......@@ -161,11 +163,12 @@ def multilevel_crop_and_resize(features,
levels = list(features.keys())
min_level = int(min(levels))
max_level = int(max(levels))
features_shape = tf.shape(features[str(min_level)])
batch_size, max_feature_height, max_feature_width, num_filters = (
features[str(min_level)].get_shape().as_list())
if batch_size is None:
batch_size = tf.shape(features[str(min_level)])[0]
_, num_boxes, _ = boxes.get_shape().as_list()
features_shape[0], features_shape[1], features_shape[2],
features_shape[3])
num_boxes = tf.shape(boxes)[1]
# Stack feature pyramid into a features_all of shape
# [batch_size, levels, height, width, num_filters].
......
......@@ -131,7 +131,7 @@ class DeepMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
model_mask_outputs = self._call_mask_outputs(
model_box_outputs=model_outputs,
features=intermediate_outputs['features'],
features=model_outputs['decoder_features'],
current_rois=intermediate_outputs['current_rois'],
matched_gt_indices=intermediate_outputs['matched_gt_indices'],
matched_gt_boxes=intermediate_outputs['matched_gt_boxes'],
......
......@@ -15,6 +15,7 @@
# Lint as: python3
"""Detection input and model functions for serving/inference."""
from typing import Mapping, Text
import tensorflow as tf
from official.vision.beta import configs
......@@ -78,13 +79,17 @@ class DetectionModule(export_base.ExportModule):
return image, anchor_boxes, image_info
def serve(self, images: tf.Tensor):
"""Cast image to float and run inference.
def preprocess(self, images: tf.Tensor) -> (
tf.Tensor, Mapping[Text, tf.Tensor], tf.Tensor):
"""Preprocess inputs to be suitable for the model.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
images: The images tensor.
Returns:
Tensor holding detection output logits.
images: The images tensor cast to float.
anchor_boxes: Dict mapping anchor levels to anchor boxes.
image_info: Tensor containing the details of the image resizing.
"""
model_params = self.params.task.model
with tf.device('cpu:0'):
......@@ -117,6 +122,18 @@ class DetectionModule(export_base.ExportModule):
image_info_spec),
parallel_iterations=32))
return images, anchor_boxes, image_info
def serve(self, images: tf.Tensor):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding detection output logits.
"""
images, anchor_boxes, image_info = self.preprocess(images)
input_image_shape = image_info[:, 1, :]
# To overcome keras.Model extra limitation to save a model with layers that
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment