Commit cd89148f authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 344355705
parent 7042ff27
@@ -107,6 +107,23 @@ def rpn_head_generator(params):
      norm_activation=norm_activation_generator(params.norm_activation))

def oln_rpn_head_generator(params):
"""Generator function for OLN-proposal (OLN-RPN) head architecture."""
head_params = params.rpn_head
anchors_per_location = params.anchor.num_scales * len(
params.anchor.aspect_ratios)
return heads.OlnRpnHead(
params.architecture.min_level,
params.architecture.max_level,
anchors_per_location,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))

def fast_rcnn_head_generator(params):
  """Generator function for Fast R-CNN head architecture."""
  head_params = params.frcnn_head
@@ -122,6 +139,21 @@ def fast_rcnn_head_generator(params):
      norm_activation=norm_activation_generator(params.norm_activation))

def oln_box_score_head_generator(params):
"""Generator function for Scoring Fast R-CNN head architecture."""
head_params = params.frcnn_head
return heads.OlnBoxScoreHead(
params.architecture.num_classes,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
head_params.num_fcs,
head_params.fc_dims,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))

def mask_rcnn_head_generator(params):
  """Generator function for Mask R-CNN head architecture."""
  head_params = params.mrcnn_head
@@ -136,6 +168,20 @@ def mask_rcnn_head_generator(params):
      norm_activation=norm_activation_generator(params.norm_activation))

def oln_mask_score_head_generator(params):
"""Generator function for Scoring Mask R-CNN head architecture."""
head_params = params.mrcnn_head
return heads.OlnMaskScoreHead(
params.architecture.num_classes,
params.architecture.mask_target_size,
head_params.num_convs,
head_params.num_filters,
head_params.use_separable_conv,
params.norm_activation.activation,
head_params.use_batch_norm,
norm_activation=norm_activation_generator(params.norm_activation))
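
# A minimal sketch (illustrative only, not part of this change) of the
# attribute-style `params` config these OLN generators read. Field names
# mirror the accesses above; the real config schema lives in the project's
# config files and carries more fields (e.g. for norm_activation_generator),
# so the generator calls themselves are left commented out.
import types

def _cfg(**kwargs):
  return types.SimpleNamespace(**kwargs)

oln_params = _cfg(
    architecture=_cfg(min_level=2, max_level=6, num_classes=2,
                      mask_target_size=28),
    anchor=_cfg(num_scales=1, aspect_ratios=[1.0, 2.0, 0.5]),
    rpn_head=_cfg(num_convs=2, num_filters=256, use_separable_conv=False,
                  use_batch_norm=True),
    frcnn_head=_cfg(num_convs=0, num_filters=256, use_separable_conv=False,
                    num_fcs=2, fc_dims=1024, use_batch_norm=True),
    mrcnn_head=_cfg(num_convs=4, num_filters=256, use_separable_conv=False,
                    use_batch_norm=True),
    norm_activation=_cfg(activation='relu'),
)
# rpn_head = oln_rpn_head_generator(oln_params)  # 1 scale * 3 ratios = 3 anchors
# box_head = oln_box_score_head_generator(oln_params)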

def shapeprior_head_generator(params):
  """Generator function for shape prior head architecture."""
  head_params = params.shapemask_head
...
@@ -137,6 +137,130 @@ class RpnHead(tf.keras.layers.Layer):
    return scores_outputs, box_outputs

class OlnRpnHead(tf.keras.layers.Layer):
"""Region Proposal Network for Object Localization Network (OLN)."""
def __init__(
self,
min_level,
max_level,
anchors_per_location,
num_convs=2,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build Region Proposal Network head.
Args:
min_level: `int` number of minimum feature level.
max_level: `int` number of maximum feature level.
anchors_per_location: `int` number of number of anchors per pixel
location.
num_convs: `int` number that represents the number of the intermediate
conv layers before the prediction.
num_filters: `int` number that represents the number of filters of the
intermediate conv layers.
use_separable_conv: `bool`, indicating whether the separable conv layers
is used.
activation: activation function. Support 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
self._min_level = min_level
self._max_level = max_level
self._anchors_per_location = anchors_per_location
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D,
depth_multiplier=1,
bias_initializer=tf.zeros_initializer())
else:
self._conv2d_op = functools.partial(
tf.keras.layers.Conv2D,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer())
self._rpn_conv = self._conv2d_op(
num_filters,
kernel_size=(3, 3),
strides=(1, 1),
activation=(None if self._use_batch_norm else self._activation_op),
padding='same',
name='rpn')
self._rpn_class_conv = self._conv2d_op(
anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-class')
self._rpn_box_conv = self._conv2d_op(
4 * anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-box-lrtb')
self._rpn_center_conv = self._conv2d_op(
anchors_per_location,
kernel_size=(1, 1),
strides=(1, 1),
padding='valid',
name='rpn-centerness')
self._norm_activations = {}
if self._use_batch_norm:
for level in range(self._min_level, self._max_level + 1):
self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
level)
def _shared_rpn_heads(self, features, anchors_per_location, level,
is_training):
"""Shared RPN heads."""
features = self._rpn_conv(features)
if self._use_batch_norm:
# The batch normalization layers are not shared between levels.
features = self._norm_activations[level](
features, is_training=is_training)
# Feature L2 normalization for training stability
    features = tf.math.l2_normalize(features, axis=-1, name='rpn-norm')
# Proposal classification scores
scores = self._rpn_class_conv(features)
# Proposal bbox regression deltas
bboxes = self._rpn_box_conv(features)
# Proposal centerness scores
centers = self._rpn_center_conv(features)
return scores, bboxes, centers
def __call__(self, features, is_training=None):
scores_outputs = {}
box_outputs = {}
center_outputs = {}
with keras_utils.maybe_enter_backend_graph(), tf.name_scope('rpn_head'):
for level in range(self._min_level, self._max_level + 1):
scores_output, box_output, center_output = self._shared_rpn_heads(
features[level], self._anchors_per_location, level, is_training)
scores_outputs[level] = scores_output
box_outputs[level] = box_output
center_outputs[level] = center_output
return scores_outputs, box_outputs, center_outputs
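
# A minimal usage sketch (illustrative, assumed shapes; not part of this
# change): run the OLN-RPN head over a dict of FPN features. Per level it
# returns objectness scores [B, H, W, A], lrtb box distances [B, H, W, 4*A],
# and centerness [B, H, W, A], with A anchors per location.
import tensorflow as tf

oln_rpn_head = OlnRpnHead(min_level=2, max_level=3, anchors_per_location=3)
fpn_features = {
    level: tf.random.normal([2, 256 // 2**level, 256 // 2**level, 256])
    for level in range(2, 4)
}
scores, lrtb_boxes, centers = oln_rpn_head(fpn_features, is_training=False)
# e.g. scores[2]: [2, 64, 64, 3], lrtb_boxes[2]: [2, 64, 64, 12],
# centers[2]: [2, 64, 64, 3].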

class FastrcnnHead(tf.keras.layers.Layer):
  """Fast R-CNN box head."""
@@ -276,6 +400,151 @@ class FastrcnnHead(tf.keras.layers.Layer):
    return class_outputs, box_outputs

class OlnBoxScoreHead(tf.keras.layers.Layer):
"""Box head of Object Localization Network (OLN)."""
def __init__(
self,
num_classes,
num_convs=0,
num_filters=256,
use_separable_conv=False,
num_fcs=2,
fc_dims=1024,
activation='relu',
use_batch_norm=True,
norm_activation=nn_ops.norm_activation_builder(activation='relu')):
"""Initialize params to build OLN box head.
Args:
num_classes: a integer for the number of classes.
num_convs: `int` number that represents the number of the intermediate
conv layers before the FC layers.
num_filters: `int` number that represents the number of filters of the
intermediate conv layers.
use_separable_conv: `bool`, indicating whether the separable conv layers
is used.
num_fcs: `int` number that represents the number of FC layers before the
predictions.
fc_dims: `int` number that represents the number of dimension of the FC
layers.
activation: activation function. Support 'relu' and 'swish'.
use_batch_norm: 'bool', indicating whether batchnorm layers are added.
norm_activation: an operation that includes a normalization layer followed
by an optional activation layer.
"""
self._num_classes = num_classes
self._num_convs = num_convs
self._num_filters = num_filters
if use_separable_conv:
self._conv2d_op = functools.partial(
tf.keras.layers.SeparableConv2D,
depth_multiplier=1,
bias_initializer=tf.zeros_initializer())
else:
self._conv2d_op = functools.partial(
tf.keras.layers.Conv2D,
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2, mode='fan_out', distribution='untruncated_normal'),
bias_initializer=tf.zeros_initializer())
self._num_fcs = num_fcs
self._fc_dims = fc_dims
if activation == 'relu':
self._activation_op = tf.nn.relu
elif activation == 'swish':
self._activation_op = tf.nn.swish
else:
raise ValueError('Unsupported activation `{}`.'.format(activation))
self._use_batch_norm = use_batch_norm
self._norm_activation = norm_activation
self._conv_ops = []
self._conv_bn_ops = []
for i in range(self._num_convs):
self._conv_ops.append(
self._conv2d_op(
self._num_filters,
kernel_size=(3, 3),
strides=(1, 1),
padding='same',
dilation_rate=(1, 1),
activation=(None
if self._use_batch_norm else self._activation_op),
name='conv_{}'.format(i)))
if self._use_batch_norm:
self._conv_bn_ops.append(self._norm_activation())
self._fc_ops = []
self._fc_bn_ops = []
for i in range(self._num_fcs):
self._fc_ops.append(
tf.keras.layers.Dense(
units=self._fc_dims,
activation=(None
if self._use_batch_norm else self._activation_op),
name='fc{}'.format(i)))
if self._use_batch_norm:
self._fc_bn_ops.append(self._norm_activation(fused=False))
self._class_predict = tf.keras.layers.Dense(
self._num_classes,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer(),
name='class-predict')
self._box_predict = tf.keras.layers.Dense(
self._num_classes * 4,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
bias_initializer=tf.zeros_initializer(),
name='box-predict')
self._score_predict = tf.keras.layers.Dense(
1,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
bias_initializer=tf.zeros_initializer(),
name='score-predict')
def __call__(self, roi_features, is_training=None):
"""Box and class branches for the Mask-RCNN model.
Args:
roi_features: A ROI feature tensor of shape [batch_size, num_rois,
height_l, width_l, num_filters].
is_training: `boolean`, if True if model is in training mode.
Returns:
class_outputs: a tensor with a shape of
[batch_size, num_rois, num_classes], representing the class predictions.
box_outputs: a tensor with a shape of
[batch_size, num_rois, num_classes * 4], representing the box
predictions.
"""
with keras_utils.maybe_enter_backend_graph(), tf.name_scope(
'fast_rcnn_head'):
      # Reshape inputs before the FC layers.
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
net = tf.reshape(roi_features, [-1, height, width, filters])
for i in range(self._num_convs):
net = self._conv_ops[i](net)
if self._use_batch_norm:
net = self._conv_bn_ops[i](net, is_training=is_training)
filters = self._num_filters if self._num_convs > 0 else filters
net = tf.reshape(net, [-1, num_rois, height * width * filters])
for i in range(self._num_fcs):
net = self._fc_ops[i](net)
if self._use_batch_norm:
net = self._fc_bn_ops[i](net, is_training=is_training)
class_outputs = self._class_predict(net)
box_outputs = self._box_predict(net)
score_outputs = self._score_predict(net)
return class_outputs, box_outputs, score_outputs
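
# A minimal usage sketch (illustrative, assumed shapes): the box-score head
# consumes ROI features of shape [batch, num_rois, h, w, filters] and returns
# class, box, and per-ROI localization-score predictions.
import tensorflow as tf

box_head = OlnBoxScoreHead(num_classes=2)
roi_features = tf.random.normal([2, 8, 7, 7, 256])
class_out, box_out, score_out = box_head(roi_features, is_training=False)
# class_out: [2, 8, 2], box_out: [2, 8, 8], score_out: [2, 8, 1].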

class MaskrcnnHead(tf.keras.layers.Layer):
  """Mask R-CNN head."""
...
@@ -16,6 +16,7 @@
from official.vision.detection.modeling import maskrcnn_model
from official.vision.detection.modeling import olnmask_model
from official.vision.detection.modeling import retinanet_model
from official.vision.detection.modeling import shapemask_model
@@ -26,6 +27,8 @@ def model_generator(params):
    model_fn = retinanet_model.RetinanetModel(params)
  elif params.type == 'mask_rcnn':
    model_fn = maskrcnn_model.MaskrcnnModel(params)
elif params.type == 'olnmask':
model_fn = olnmask_model.OlnMaskModel(params)
  elif params.type == 'shapemask':
    model_fn = shapemask_model.ShapeMaskModel(params)
  else:
...
@@ -195,6 +195,150 @@ class RpnBoxLoss(object):
    return box_loss

class OlnRpnCenterLoss(object):
"""Object Localization Network RPN centerness regression loss function."""
def __init__(self):
self._l1_loss = tf.keras.losses.MeanAbsoluteError(
reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, center_outputs, labels):
"""Computes total RPN centerness regression loss.
Computes total RPN centerness score regression loss from all levels.
Args:
center_outputs: an OrderDict with keys representing levels and values
representing anchor centerness regression targets in
[batch_size, height, width, num_anchors * 4].
labels: the dictionary that returned from dataloader that includes
groundturth targets.
Returns:
rpn_center_loss: a scalar tensor representing total centerness regression
loss.
"""
with tf.name_scope('rpn_loss'):
# Normalizer.
levels = sorted(center_outputs.keys())
num_valid = 0
# 0<pos<1, neg=0, ign=-1
for level in levels:
num_valid += tf.reduce_sum(tf.cast(
tf.greater(labels[level], -1.0), tf.float32)) # in and out of box
num_valid += 1e-12
# Centerness loss over multi levels.
center_losses = []
for level in levels:
center_losses.append(
self._rpn_center_l1_loss(
center_outputs[level], labels[level],
normalizer=num_valid))
# Sum per level losses to total loss.
return tf.add_n(center_losses)
def _rpn_center_l1_loss(self, center_outputs, center_targets,
normalizer=1.0):
"""Computes centerness regression loss."""
    # For instance, the regression targets of a 512x512 input with 6 anchors
    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
with tf.name_scope('rpn_center_loss'):
# mask = tf.greater(center_targets, 0.0) # inside box only.
mask = tf.greater(center_targets, -1.0) # in and out of box.
center_targets = tf.maximum(center_targets, tf.zeros_like(center_targets))
center_outputs = tf.sigmoid(center_outputs)
center_targets = tf.expand_dims(center_targets, -1)
center_outputs = tf.expand_dims(center_outputs, -1)
mask = tf.cast(mask, dtype=tf.float32)
center_loss = self._l1_loss(center_targets, center_outputs,
sample_weight=mask)
center_loss /= normalizer
return center_loss
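
# A small sketch (illustrative) of the centerness loss on dummy data. Targets
# are in (0, 1] at valid locations and -1 at ignored ones, matching the
# "0<pos<1, neg=0, ign=-1" convention noted above; zero logits give
# sigmoid = 0.5, so matching 0.5 targets yield a zero loss.
import tensorflow as tf

center_loss_fn = OlnRpnCenterLoss()
dummy_outputs = {2: tf.zeros([2, 16, 16, 3])}        # raw logits per level
dummy_targets = {2: tf.fill([2, 16, 16, 3], 0.5)}    # centerness targets
loss = center_loss_fn(dummy_outputs, dummy_targets)  # -> 0.0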
class OlnRpnIoULoss(object):
"""Object Localization Network RPN box-lrtb regression iou loss function."""
def __call__(self, box_outputs, labels, center_targets):
"""Computes total RPN detection loss.
Computes total RPN box regression loss from all levels.
Args:
box_outputs: an OrderDict with keys representing levels and values
representing box regression targets in
[batch_size, height, width, num_anchors * 4].
last channel: (left, right, top, bottom).
labels: the dictionary that returned from dataloader that includes
groundturth targets (left, right, top, bottom).
center_targets: valid_target mask.
Returns:
rpn_iou_loss: a scalar tensor representing total box regression loss.
"""
with tf.name_scope('rpn_loss'):
# Normalizer.
levels = sorted(box_outputs.keys())
normalizer = 0.
for level in levels:
# center_targets pos>0, neg=0, ign=-1.
mask_ = tf.cast(tf.logical_and(
tf.greater(center_targets[level][..., 0], 0.0),
tf.greater(tf.reduce_min(labels[level], -1), 0.0)), tf.float32)
normalizer += tf.reduce_sum(mask_)
normalizer += 1e-8
# iou_loss over multi levels.
iou_losses = []
for level in levels:
iou_losses.append(
self._rpn_iou_loss(
box_outputs[level], labels[level],
center_weight=center_targets[level][..., 0],
normalizer=normalizer))
# Sum per level losses to total loss.
return tf.add_n(iou_losses)
def _rpn_iou_loss(self, box_outputs, box_targets,
center_weight=None, normalizer=1.0):
"""Computes box regression loss."""
    # For instance, the regression targets of a 512x512 input with 6 anchors
    # on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
with tf.name_scope('rpn_iou_loss'):
mask = tf.logical_and(
tf.greater(center_weight, 0.0),
tf.greater(tf.reduce_min(box_targets, -1), 0.0))
pred_left = box_outputs[..., 0]
pred_right = box_outputs[..., 1]
pred_top = box_outputs[..., 2]
pred_bottom = box_outputs[..., 3]
gt_left = box_targets[..., 0]
gt_right = box_targets[..., 1]
gt_top = box_targets[..., 2]
gt_bottom = box_targets[..., 3]
inter_width = (tf.minimum(pred_left, gt_left) +
tf.minimum(pred_right, gt_right))
inter_height = (tf.minimum(pred_top, gt_top) +
tf.minimum(pred_bottom, gt_bottom))
inter_area = inter_width * inter_height
union_area = ((pred_left + pred_right) * (pred_top + pred_bottom) +
(gt_left + gt_right) * (gt_top + gt_bottom) -
inter_area)
iou = inter_area / (union_area + 1e-8)
iou = tf.where(mask, iou, tf.ones_like(iou))
mask_ = tf.cast(mask, tf.float32)
iou = tf.clip_by_value(iou, clip_value_min=1e-8, clip_value_max=1.0)
neg_log_iou = -tf.math.log(iou)
iou_loss = tf.reduce_sum(neg_log_iou * mask_)
iou_loss /= normalizer
return iou_loss
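
# A toy check (illustrative) of the lrtb IoU used above: box targets are
# (left, right, top, bottom) distances from a location to the box edges, so
# identical predictions and targets give IoU = 1 and -log(IoU) = 0.
import tensorflow as tf

pred = tf.constant([[0.1, 0.1, 0.2, 0.2]])  # (l, r, t, b)
gt = tf.constant([[0.1, 0.1, 0.2, 0.2]])
inter_w = (tf.minimum(pred[..., 0], gt[..., 0]) +
           tf.minimum(pred[..., 1], gt[..., 1]))
inter_h = (tf.minimum(pred[..., 2], gt[..., 2]) +
           tf.minimum(pred[..., 3], gt[..., 3]))
inter = inter_w * inter_h
union = ((pred[..., 0] + pred[..., 1]) * (pred[..., 2] + pred[..., 3]) +
         (gt[..., 0] + gt[..., 1]) * (gt[..., 2] + gt[..., 3]) - inter)
iou = inter / (union + 1e-8)  # -> [1.0]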

class FastrcnnClassLoss(object):
  """Fast R-CNN classification loss function."""
@@ -317,6 +461,47 @@ class FastrcnnBoxLoss(object):
    return box_loss

class OlnBoxScoreLoss(object):
"""Object Localization Network Box-Iou scoring function."""
def __init__(self, params):
self._ignore_threshold = params.ignore_threshold
self._l1_loss = tf.keras.losses.MeanAbsoluteError(
reduction=tf.keras.losses.Reduction.SUM)
def __call__(self, score_outputs, score_targets):
"""Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
This function implements the classification loss of the Fast-RCNN.
The classification loss is softmax on all RoIs.
Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
Args:
score_outputs: a float tensor representing the class prediction for each box
with a shape of [batch_size, num_boxes, num_classes].
score_targets: a float tensor representing the class label for each box
with a shape of [batch_size, num_boxes].
Returns:
a scalar tensor representing total score loss.
"""
with tf.name_scope('fast_rcnn_loss'):
score_outputs = tf.squeeze(score_outputs, -1)
mask = tf.greater(score_targets, self._ignore_threshold)
num_valid = tf.reduce_sum(tf.cast(mask, tf.float32))
score_targets = tf.maximum(score_targets, tf.zeros_like(score_targets))
score_outputs = tf.sigmoid(score_outputs)
score_targets = tf.expand_dims(score_targets, -1)
score_outputs = tf.expand_dims(score_outputs, -1)
mask = tf.cast(mask, dtype=tf.float32)
score_loss = self._l1_loss(score_targets, score_outputs,
sample_weight=mask)
score_loss /= num_valid
return score_loss
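
# A minimal sketch (illustrative): the score loss is an L1 between sigmoid
# score predictions and box-IoU targets, skipping targets at or below
# `ignore_threshold`. The `types.SimpleNamespace` object here is a stand-in
# for the real `frcnn_box_score_loss` config.
import types
import tensorflow as tf

score_loss_fn = OlnBoxScoreLoss(types.SimpleNamespace(ignore_threshold=-1.0))
score_outputs = tf.zeros([2, 8, 1])   # logits -> sigmoid = 0.5
score_targets = tf.fill([2, 8], 0.5)  # box-IoU targets in [0, 1]
loss = score_loss_fn(score_outputs, score_targets)  # -> 0.0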

class MaskrcnnLoss(object):
  """Mask R-CNN instance segmentation mask loss function."""
...

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model defination for the Object Localization Network (OLN) Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.vision.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys
from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.maskrcnn_model import MaskrcnnModel
from official.vision.detection.ops import postprocess_ops
from official.vision.detection.ops import roi_ops
from official.vision.detection.ops import spatial_transform_ops
from official.vision.detection.ops import target_ops
from official.vision.detection.utils import box_utils
class OlnMaskModel(MaskrcnnModel):
"""OLN-Mask model function."""
def __init__(self, params):
super(OlnMaskModel, self).__init__(params)
self._params = params
# Different heads and layers.
self._include_rpn_class = params.architecture.include_rpn_class
self._include_mask = params.architecture.include_mask
self._include_frcnn_class = params.architecture.include_frcnn_class
self._include_frcnn_box = params.architecture.include_frcnn_box
self._include_centerness = params.rpn_head.has_centerness
self._include_box_score = (params.frcnn_head.has_scoring and
params.architecture.include_frcnn_box)
self._include_mask_score = (params.mrcnn_head.has_scoring and
params.architecture.include_mask)
# Architecture generators.
self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params)
    if self._include_centerness:
      self._rpn_head_fn = factory.oln_rpn_head_generator(params)
    else:
      self._rpn_head_fn = factory.rpn_head_generator(params)
self._generate_rois_fn = roi_ops.OlnROIGenerator(params.roi_proposal)
self._sample_rois_fn = target_ops.ROIScoreSampler(params.roi_sampling)
self._sample_masks_fn = target_ops.MaskSampler(
params.architecture.mask_target_size,
params.mask_sampling.num_mask_samples_per_image)
if self._include_box_score:
self._frcnn_head_fn = factory.oln_box_score_head_generator(params)
else:
self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
if self._include_mask:
if self._include_mask_score:
self._mrcnn_head_fn = factory.oln_mask_score_head_generator(params)
else:
self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
# Loss function.
self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
if self._include_centerness:
self._rpn_iou_loss_fn = losses.OlnRpnIoULoss()
self._rpn_center_loss_fn = losses.OlnRpnCenterLoss()
self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
if self._include_box_score:
self._frcnn_box_score_loss_fn = losses.OlnBoxScoreLoss(
params.frcnn_box_score_loss)
if self._include_mask:
self._mask_loss_fn = losses.MaskrcnnLoss()
self._generate_detections_fn = postprocess_ops.OlnDetectionGenerator(
params.postprocess)
self._transpose_input = params.train.transpose_input
    assert not self._transpose_input, 'Transpose input is not supported.'
def build_outputs(self, inputs, mode):
is_training = mode == mode_keys.TRAIN
model_outputs = {}
image = inputs['image']
_, image_height, image_width, _ = image.get_shape().as_list()
backbone_features = self._backbone_fn(image, is_training)
fpn_features = self._fpn_fn(backbone_features, is_training)
# rpn_centerness.
if self._include_centerness:
rpn_score_outputs, rpn_box_outputs, rpn_center_outputs = (
self._rpn_head_fn(fpn_features, is_training))
model_outputs.update({
'rpn_center_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_center_outputs),
})
object_scores = rpn_center_outputs
else:
rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
fpn_features, is_training)
object_scores = None
model_outputs.update({
'rpn_score_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_score_outputs),
'rpn_box_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
rpn_box_outputs),
})
input_anchor = anchor.Anchor(self._params.architecture.min_level,
self._params.architecture.max_level,
self._params.anchor.num_scales,
self._params.anchor.aspect_ratios,
self._params.anchor.anchor_size,
(image_height, image_width))
rpn_rois, rpn_roi_scores = self._generate_rois_fn(
rpn_box_outputs,
rpn_score_outputs,
input_anchor.multilevel_boxes,
inputs['image_info'][:, 1, :],
is_training,
is_box_lrtb=self._include_centerness,
object_scores=object_scores,
)
if (not self._include_frcnn_class and
not self._include_frcnn_box and
not self._include_mask):
      # For direct RPN detection, use dummy box outputs
      # (dy, dx, dh, dw = 0, 0, 0, 0).
box_outputs = tf.zeros_like(rpn_rois)
box_outputs = tf.concat([box_outputs, box_outputs], -1)
boxes, scores, classes, valid_detections = self._generate_detections_fn(
box_outputs, rpn_roi_scores, rpn_rois,
inputs['image_info'][:, 1:2, :],
is_single_fg_score=True, # if no_background, no softmax is applied.
keep_nms=True)
model_outputs.update({
'num_detections': valid_detections,
'detection_boxes': boxes,
'detection_classes': classes,
'detection_scores': scores,
})
return model_outputs
# ---- OLN-Proposal finishes here. ----
if is_training:
rpn_rois = tf.stop_gradient(rpn_rois)
rpn_roi_scores = tf.stop_gradient(rpn_roi_scores)
# Sample proposals.
(rpn_rois, rpn_roi_scores, matched_gt_boxes, matched_gt_classes,
matched_gt_indices) = (
self._sample_rois_fn(rpn_rois, rpn_roi_scores, inputs['gt_boxes'],
inputs['gt_classes']))
# Create bounding box training targets.
box_targets = box_utils.encode_boxes(
matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]), tf.zeros_like(box_targets), box_targets)
model_outputs.update({
'class_targets': matched_gt_classes,
'box_targets': box_targets,
})
      # Create Box-IoU targets.
box_ious = box_utils.bbox_overlap(
rpn_rois, inputs['gt_boxes'])
matched_box_ious = tf.reduce_max(box_ious, 2)
      model_outputs.update({
          'box_iou_targets': matched_box_ious})
roi_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, rpn_rois, output_size=7)
if not self._include_box_score:
class_outputs, box_outputs = self._frcnn_head_fn(
roi_features, is_training)
else:
class_outputs, box_outputs, score_outputs = self._frcnn_head_fn(
roi_features, is_training)
model_outputs.update({
'box_score_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
score_outputs),})
model_outputs.update({
'class_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
class_outputs),
'box_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
box_outputs),
})
# Add this output to train to make the checkpoint loadable in predict mode.
# If we skip it in train mode, the heads will be out-of-order and checkpoint
# loading will fail.
if not self._include_frcnn_box:
box_outputs = tf.zeros_like(box_outputs) # dummy zeros.
if self._include_box_score:
score_outputs = tf.cast(tf.squeeze(score_outputs, -1),
rpn_roi_scores.dtype)
      # box-score = (rpn-centerness * box-iou)^(1/2).
      # Training: rpn_roi_scores is [b, 1000] and score_outputs is [b, 512].
      # Eval: rpn_roi_scores is [b, 1000] and score_outputs is [b, 1000].
box_scores = tf.pow(
rpn_roi_scores * tf.sigmoid(score_outputs), 1/2.)
if not self._include_frcnn_class:
boxes, scores, classes, valid_detections = self._generate_detections_fn(
box_outputs,
box_scores,
rpn_rois,
inputs['image_info'][:, 1:2, :],
is_single_fg_score=True,
keep_nms=True,)
else:
boxes, scores, classes, valid_detections = self._generate_detections_fn(
box_outputs, class_outputs, rpn_rois,
inputs['image_info'][:, 1:2, :],
keep_nms=True,)
model_outputs.update({
'num_detections': valid_detections,
'detection_boxes': boxes,
'detection_classes': classes,
'detection_scores': scores,
})
# ---- OLN-Box finishes here. ----
if not self._include_mask:
return model_outputs
if is_training:
rpn_rois, classes, mask_targets = self._sample_masks_fn(
rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
inputs['gt_masks'])
mask_targets = tf.stop_gradient(mask_targets)
classes = tf.cast(classes, dtype=tf.int32)
model_outputs.update({
'mask_targets': mask_targets,
'sampled_class_targets': classes,
})
else:
rpn_rois = boxes
classes = tf.cast(classes, dtype=tf.int32)
mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
fpn_features, rpn_rois, output_size=14)
mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)
if is_training:
model_outputs.update({
'mask_outputs':
tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
mask_outputs),
})
else:
model_outputs.update({'detection_masks': tf.nn.sigmoid(mask_outputs)})
return model_outputs
def build_loss_fn(self):
if self._keras_model is None:
raise ValueError('build_loss_fn() must be called after build_model().')
filter_fn = self.make_filter_trainable_variables_fn()
trainable_variables = filter_fn(self._keras_model.trainable_variables)
def _total_loss_fn(labels, outputs):
if self._include_rpn_class:
rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
labels['rpn_score_targets'])
else:
rpn_score_loss = 0.0
if self._include_centerness:
rpn_center_loss = self._rpn_center_loss_fn(
outputs['rpn_center_outputs'], labels['rpn_center_targets'])
rpn_box_loss = self._rpn_iou_loss_fn(
outputs['rpn_box_outputs'], labels['rpn_box_targets'],
labels['rpn_center_targets'])
else:
rpn_center_loss = 0.0
rpn_box_loss = self._rpn_box_loss_fn(
outputs['rpn_box_outputs'], labels['rpn_box_targets'])
if self._include_frcnn_class:
frcnn_class_loss = self._frcnn_class_loss_fn(
outputs['class_outputs'], outputs['class_targets'])
else:
frcnn_class_loss = 0.0
if self._include_frcnn_box:
frcnn_box_loss = self._frcnn_box_loss_fn(
outputs['box_outputs'], outputs['class_targets'],
outputs['box_targets'])
else:
frcnn_box_loss = 0.0
if self._include_box_score:
box_score_loss = self._frcnn_box_score_loss_fn(
outputs['box_score_outputs'], outputs['box_iou_targets'])
else:
box_score_loss = 0.0
if self._include_mask:
mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
outputs['mask_targets'],
outputs['sampled_class_targets'])
else:
mask_loss = 0.0
model_loss = (
rpn_score_loss + rpn_box_loss + rpn_center_loss +
frcnn_class_loss + frcnn_box_loss + box_score_loss +
mask_loss)
l2_regularization_loss = self.weight_decay_loss(trainable_variables)
total_loss = model_loss + l2_regularization_loss
return {
'total_loss': total_loss,
'loss': total_loss,
'fast_rcnn_class_loss': frcnn_class_loss,
'fast_rcnn_box_loss': frcnn_box_loss,
'fast_rcnn_box_score_loss': box_score_loss,
'mask_loss': mask_loss,
'model_loss': model_loss,
'l2_regularization_loss': l2_regularization_loss,
'rpn_score_loss': rpn_score_loss,
'rpn_box_loss': rpn_box_loss,
'rpn_center_loss': rpn_center_loss,
}
return _total_loss_fn
def build_input_layers(self, params, mode):
is_training = mode == mode_keys.TRAIN
input_shape = (
params.olnmask_parser.output_size +
[params.olnmask_parser.num_channels])
if is_training:
batch_size = params.train.batch_size
input_layer = {
'image':
tf.keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf.keras.layers.Input(
shape=[4, 2],
batch_size=batch_size,
name='image_info',
),
'gt_boxes':
tf.keras.layers.Input(
shape=[params.olnmask_parser.max_num_instances, 4],
batch_size=batch_size,
name='gt_boxes'),
'gt_classes':
tf.keras.layers.Input(
shape=[params.olnmask_parser.max_num_instances],
batch_size=batch_size,
name='gt_classes',
dtype=tf.int64),
}
if self._include_mask:
input_layer['gt_masks'] = tf.keras.layers.Input(
shape=[
params.olnmask_parser.max_num_instances,
params.olnmask_parser.mask_crop_size,
params.olnmask_parser.mask_crop_size
],
batch_size=batch_size,
name='gt_masks')
else:
batch_size = params.eval.batch_size
input_layer = {
'image':
tf.keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf.keras.layers.Input(
shape=[4, 2],
batch_size=batch_size,
name='image_info',
),
}
return input_layer
def build_model(self, params, mode):
if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode)
with keras_utils.maybe_enter_backend_graph():
outputs = self.model_outputs(input_layers, mode)
model = tf.keras.models.Model(
inputs=input_layers, outputs=outputs, name='olnmask')
      assert model is not None, 'Failed to build tf.keras.Model.'
model.optimizer = self.build_optimizer()
self._keras_model = model
return self._keras_model