Commit d1a3fa6a authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 370677512
parent 834c0454
@@ -100,6 +100,10 @@ class DetectionHead(hyperparams.Config):
use_separable_conv: bool = False
num_fcs: int = 1
fc_dims: int = 1024
class_agnostic_bbox_pred: bool = False # Has to be True for Cascade RCNN.
# If additional IoUs are passed in 'cascade_iou_thresholds'
# then ensemble the class probabilities from all heads.
cascade_class_ensemble: bool = False
@dataclasses.dataclass
@@ -125,6 +129,9 @@ class ROISampler(hyperparams.Config):
foreground_iou_threshold: float = 0.5
background_iou_high_threshold: float = 0.5
background_iou_low_threshold: float = 0.0
# IoU thresholds for additional FRCNN heads in Cascade mode.
# `foreground_iou_threshold` is the first threshold.
cascade_iou_thresholds: Optional[List[float]] = None
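Taken together with the new `DetectionHead` fields above, here is a minimal sketch of a Cascade R-CNN setup; the import path is an assumption, and the values mirror the `cascadercnn_resnetfpn_coco` config added below.

# Sketch only: the module path below is assumed, not part of this change.
from official.vision.beta.configs import maskrcnn as maskrcnn_cfg

detection_head = maskrcnn_cfg.DetectionHead(
    class_agnostic_bbox_pred=True,   # has to be True for Cascade RCNN
    cascade_class_ensemble=True)     # ensemble class probabilities over heads
roi_sampler = maskrcnn_cfg.ROISampler(
    foreground_iou_threshold=0.5,        # IoU for the first head
    cascade_iou_thresholds=[0.6, 0.7])   # IoUs for the additional heads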
@dataclasses.dataclass
@@ -282,7 +289,6 @@ def fasterrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@@ -292,6 +298,68 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
steps_per_epoch = 500
coco_val_samples = 5000
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=MaskRCNNTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=MaskRCNN(
num_classes=91, input_size=[1024, 1024, 3], include_mask=True),
losses=Losses(l2_weight_decay=0.00004),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=64,
parser=Parser(
aug_rand_hflip=True, aug_scale_min=0.8, aug_scale_max=1.25)),
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=8)),
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // 8,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [15000, 20000],
'values': [0.12, 0.012, 0.0012],
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 500,
'warmup_learning_rate': 0.0067
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('cascadercnn_resnetfpn_coco')
def cascadercnn_resnetfpn_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Cascade R-CNN."""
steps_per_epoch = 500
coco_val_samples = 5000
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=MaskRCNNTask(
@@ -302,7 +370,10 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
model=MaskRCNN(
num_classes=91,
input_size=[1024, 1024, 3],
include_mask=True,
roi_sampler=ROISampler(cascade_iou_thresholds=[0.6, 0.7]),
detection_head=DetectionHead(
class_agnostic_bbox_pred=True, cascade_class_ensemble=True)),
losses=Losses(l2_weight_decay=0.00004),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
@@ -347,7 +418,6 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
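Once registered, the new experiment can be retrieved by name through the experiment factory; a short usage sketch (the `exp_factory` import path is assumed):

from official.core import exp_factory  # assumed import path

config = exp_factory.get_exp_config('cascadercnn_resnetfpn_coco')
print(config.task.model.roi_sampler.cascade_iou_thresholds)       # [0.6, 0.7]
print(config.task.model.detection_head.class_agnostic_bbox_pred)  # True
print(config.task.model.detection_head.cascade_class_ensemble)    # True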
@@ -173,13 +173,21 @@ class FastrcnnClassLoss(object):
class FastrcnnBoxLoss(object):
"""Fast R-CNN box regression loss function."""
def __init__(self,
huber_loss_delta: float,
class_agnostic_bbox_pred: bool = False):
"""Initializes the Fast R-CNN box loss.
Args:
huber_loss_delta: the delta is typically around the mean value of the
regression target. For instance, the regression targets of a 512x512
input with 6 anchors on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
class_agnostic_bbox_pred: if True, class agnostic bounding box prediction
is performed.
"""
self._huber_loss = tf.keras.losses.Huber(
delta=huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
self._class_agnostic_bbox_pred = class_agnostic_bbox_pred
def __call__(self, box_outputs, class_targets, box_targets):
"""Computes the box loss (Fast-RCNN branch) of Mask-RCNN.
@@ -207,32 +215,35 @@ class FastrcnnBoxLoss(object):
"""
with tf.name_scope('fast_rcnn_loss'):
class_targets = tf.cast(class_targets, dtype=tf.int32)
if not self._class_agnostic_bbox_pred:
box_outputs = self._assign_class_targets(box_outputs, class_targets)
return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)
def _assign_class_targets(self, box_outputs, class_targets):
"""Selects the box from `box_outputs` based on `class_targets`, with which the box has the maximum overlap."""
(batch_size, num_rois,
num_class_specific_boxes) = box_outputs.get_shape().as_list()
num_classes = num_class_specific_boxes // 4
box_outputs = tf.reshape(box_outputs,
[batch_size, num_rois, num_classes, 4])
box_indices = tf.reshape(
class_targets + tf.tile(
tf.expand_dims(tf.range(batch_size) * num_rois * num_classes, 1),
[1, num_rois]) + tf.tile(
tf.expand_dims(tf.range(num_rois) * num_classes, 0),
[batch_size, 1]), [-1])
box_outputs = tf.matmul(
tf.one_hot(
box_indices,
batch_size * num_rois * num_classes,
dtype=box_outputs.dtype), tf.reshape(box_outputs, [-1, 4]))
box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])
return box_outputs
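For intuition, the one-hot matmul above picks, for every RoI, the four box deltas that belong to its target class. Below is a minimal sketch of the same selection written with `tf.gather` (illustration only, not the code path used by this loss):

import tensorflow as tf

batch_size, num_rois, num_classes = 2, 3, 4
box_outputs = tf.random.normal([batch_size, num_rois, num_classes * 4])
class_targets = tf.constant([[0, 2, 1], [3, 0, 2]])

per_class = tf.reshape(box_outputs, [batch_size, num_rois, num_classes, 4])
# Select the class-specific deltas of each RoI: [batch_size, num_rois, 4].
selected = tf.gather(per_class, class_targets, axis=2, batch_dims=2)
print(selected.shape)  # (2, 3, 4)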
def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
normalizer=1.0):
"""Computes box regression loss."""
@@ -299,4 +310,3 @@ class MaskrcnnLoss(object):
# The loss is normalized by the number of 1's in weights and
# + 0.01 is used to avoid division by zero.
return mask_loss / (tf.reduce_sum(weights) + 0.01)
@@ -109,6 +109,7 @@ def build_maskrcnn(
use_separable_conv=detection_head_config.use_separable_conv,
num_fcs=detection_head_config.num_fcs,
fc_dims=detection_head_config.fc_dims,
class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
@@ -131,6 +132,7 @@ def build_maskrcnn(
test_num_proposals=roi_generator_config.test_num_proposals,
use_batched_nms=roi_generator_config.use_batched_nms)
roi_sampler_cascade = []
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
num_sampled_rois=roi_sampler_config.num_sampled_rois,
@@ -140,6 +142,18 @@ def build_maskrcnn(
roi_sampler_config.background_iou_high_threshold),
background_iou_low_threshold=(
roi_sampler_config.background_iou_low_threshold))
roi_sampler_cascade.append(roi_sampler_obj)
# Initialize additional ROI samplers for cascade heads.
if roi_sampler_config.cascade_iou_thresholds:
for iou in roi_sampler_config.cascade_iou_thresholds:
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=False,
num_sampled_rois=roi_sampler_config.num_sampled_rois,
foreground_iou_threshold=iou,
background_iou_high_threshold=iou,
background_iou_low_threshold=0.0,
skip_subsampling=True)
roi_sampler_cascade.append(roi_sampler_obj)
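To make the resulting sampler schedule concrete, here is a small stand-alone sketch of the loop above using plain tuples as stand-ins for `ROISampler` instances (hypothetical helper, for illustration only):

def cascade_sampler_schedule(foreground_iou, cascade_ious):
  """Returns (iou, skip_subsampling) pairs, one per detection head."""
  schedule = [(foreground_iou, False)]  # first head: full fg/bg subsampling
  for iou in cascade_ious or []:
    schedule.append((iou, True))        # cascade heads: no subsampling
  return schedule

print(cascade_sampler_schedule(0.5, [0.6, 0.7]))
# [(0.5, False), (0.6, True), (0.7, True)]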
roi_aligner_obj = roi_aligner.MultilevelROIAligner(
crop_size=roi_aligner_config.crop_size,
@@ -186,12 +200,14 @@ def build_maskrcnn(
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator_obj,
roi_sampler=roi_sampler_cascade,
roi_aligner=roi_aligner_obj,
detection_generator=detection_generator_obj,
mask_head=mask_head,
mask_sampler=mask_sampler_obj,
mask_roi_aligner=mask_roi_aligner_obj,
class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
cascade_class_ensemble=detection_head_config.cascade_class_ensemble)
return model
@@ -33,6 +33,7 @@ class DetectionHead(tf.keras.layers.Layer):
use_separable_conv: bool = False,
num_fcs: int = 2,
fc_dims: int = 1024,
class_agnostic_bbox_pred: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
@@ -54,6 +55,8 @@ class DetectionHead(tf.keras.layers.Layer):
the predictions.
fc_dims: An `int` number that represents the number of dimension of the FC
layers.
class_agnostic_bbox_pred: `bool`. If True, predicts a single set of box
deltas shared by all classes instead of one set per class.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
@@ -73,6 +76,7 @@ class DetectionHead(tf.keras.layers.Layer):
'use_separable_conv': use_separable_conv,
'num_fcs': num_fcs,
'fc_dims': fc_dims,
'class_agnostic_bbox_pred': class_agnostic_bbox_pred,
'activation': activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
@@ -155,8 +159,11 @@ class DetectionHead(tf.keras.layers.Layer):
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
name='detection-scores')
num_box_outputs = (4 if self._config_dict['class_agnostic_bbox_pred'] else
self._config_dict['num_classes'] * 4)
self._box_regressor = tf.keras.layers.Dense(
units=num_box_outputs,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=self._config_dict['kernel_regularizer'],
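A quick check of the regressor's output width under the two settings, mirroring the expression above (num_classes=91 as in the COCO configs):

num_classes = 91
for class_agnostic_bbox_pred in (False, True):
  num_box_outputs = 4 if class_agnostic_bbox_pred else num_classes * 4
  print(class_agnostic_bbox_pred, num_box_outputs)
# False 364
# True 4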
@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains definitions of generators to generate the final detections."""
from typing import List, Optional, Mapping
# Import libraries
import tensorflow as tf
@@ -432,8 +432,13 @@ class DetectionGenerator(tf.keras.layers.Layer):
}
super(DetectionGenerator, self).__init__(**kwargs)
def __call__(self,
raw_boxes: tf.Tensor,
raw_scores: tf.Tensor,
anchor_boxes: tf.Tensor,
image_shape: tf.Tensor,
regression_weights: Optional[List[float]] = None,
bbox_per_class: bool = True):
"""Generates final detections. """Generates final detections.
Args: Args:
...@@ -446,6 +451,8 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -446,6 +451,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
image_shape: A `tf.Tensor` of shape of `[batch_size, 2]` storing the image image_shape: A `tf.Tensor` of shape of `[batch_size, 2]` storing the image
height and width w.r.t. the scaled image, i.e. the same image space as height and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`. `box_outputs` and `anchor_boxes`.
regression_weights: A list of four float numbers to scale coordinates.
bbox_per_class: A `bool`. If True, perform per-class box regression.
Returns:
If `apply_nms` = True, the return is a dictionary with keys:
@@ -473,27 +480,32 @@ class DetectionGenerator(tf.keras.layers.Layer):
batch_size = box_scores_shape[0]
num_locations = box_scores_shape_list[1]
num_classes = box_scores_shape_list[-1]
box_scores = tf.slice(box_scores, [0, 0, 1], [-1, -1, -1])
if bbox_per_class:
num_detections = num_locations * (num_classes - 1)
raw_boxes = tf.reshape(raw_boxes,
[batch_size, num_locations, num_classes, 4])
raw_boxes = tf.slice(raw_boxes, [0, 0, 1, 0], [-1, -1, -1, -1])
anchor_boxes = tf.tile(
tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_classes - 1, 1])
raw_boxes = tf.reshape(raw_boxes, [batch_size, num_detections, 4])
anchor_boxes = tf.reshape(anchor_boxes, [batch_size, num_detections, 4])
# Box decoding.
decoded_boxes = box_ops.decode_boxes(
raw_boxes, anchor_boxes, weights=regression_weights)
# Box clipping
decoded_boxes = box_ops.clip_boxes(
decoded_boxes, tf.expand_dims(image_shape, axis=1))
if bbox_per_class:
decoded_boxes = tf.reshape(
decoded_boxes, [batch_size, num_locations, num_classes - 1, 4])
else:
decoded_boxes = tf.expand_dims(decoded_boxes, axis=2)
if not self._config_dict['apply_nms']:
return {
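For shape intuition, the class-agnostic branch keeps a single decoded box per location and only adds a class axis of size one; a toy comparison of the two layouts (illustration only):

import tensorflow as tf

batch, num_locations, num_classes = 2, 5, 4
# bbox_per_class=True: one decoded box per (location, non-background class).
per_class = tf.random.normal([batch, num_locations * (num_classes - 1), 4])
per_class = tf.reshape(per_class, [batch, num_locations, num_classes - 1, 4])
# bbox_per_class=False: one decoded box per location, shared by all classes.
agnostic = tf.random.normal([batch, num_locations, 4])
agnostic = tf.expand_dims(agnostic, axis=2)
print(per_class.shape, agnostic.shape)  # (2, 5, 3, 4) (2, 5, 1, 4)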
@@ -31,6 +31,7 @@ class ROISampler(tf.keras.layers.Layer):
foreground_iou_threshold: float = 0.5,
background_iou_high_threshold: float = 0.5,
background_iou_low_threshold: float = 0,
skip_subsampling: bool = False,
**kwargs):
"""Initializes a ROI sampler.
@@ -48,6 +49,9 @@ class ROISampler(tf.keras.layers.Layer):
background_iou_low_threshold: A `float` that represents the IoU threshold
for a box to be considered as negative (if overlap in
[`background_iou_low_threshold`, `background_iou_high_threshold`])
skip_subsampling: a bool that determines if we want to skip the sampling
procedure that balances the fg/bg classes. Used for upper frcnn layers
in cascade RCNN.
**kwargs: Additional keyword arguments passed to Layer.
"""
self._config_dict = {
@@ -57,6 +61,7 @@ class ROISampler(tf.keras.layers.Layer):
'foreground_iou_threshold': foreground_iou_threshold,
'background_iou_high_threshold': background_iou_high_threshold,
'background_iou_low_threshold': background_iou_low_threshold,
'skip_subsampling': skip_subsampling,
}
self._sim_calc = keras_cv.ops.IouSimilarity()
@@ -110,8 +115,8 @@ class ROISampler(tf.keras.layers.Layer):
tensor, i.e.,
gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
"""
if self._config_dict['mix_gt_boxes']:
gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
boxes = tf.concat([boxes, gt_boxes], axis=1)
boxes_invalid_mask = tf.less(
@@ -143,6 +148,10 @@ class ROISampler(tf.keras.layers.Layer):
tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
matched_gt_indices)
if self._config_dict['skip_subsampling']:
return (boxes, matched_gt_boxes, tf.squeeze(matched_gt_classes,
axis=-1), matched_gt_indices)
sampled_indices = self._sampler(
positive_matches, negative_matches, ignored_matches)
@@ -14,7 +14,7 @@
"""Mask R-CNN model."""
from typing import Any, List, Mapping, Optional, Union
# Import libraries
import tensorflow as tf
@@ -32,12 +32,15 @@ class MaskRCNNModel(tf.keras.Model):
rpn_head: tf.keras.layers.Layer,
detection_head: tf.keras.layers.Layer,
roi_generator: tf.keras.layers.Layer,
roi_sampler: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer,
mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
class_agnostic_bbox_pred: bool = False,
cascade_class_ensemble: bool = False,
**kwargs):
"""Initializes the Mask R-CNN model.
@@ -47,12 +50,17 @@ class MaskRCNNModel(tf.keras.Model):
rpn_head: the RPN head.
detection_head: the detection head.
roi_generator: the ROI generator.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
mask_roi_aligner: the ROI aligner for mask prediction.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over
all detection heads.
**kwargs: keyword arguments to be passed.
"""
super(MaskRCNNModel, self).__init__(**kwargs)
@@ -68,13 +76,22 @@ class MaskRCNNModel(tf.keras.Model):
'mask_head': mask_head,
'mask_sampler': mask_sampler,
'mask_roi_aligner': mask_roi_aligner,
'class_agnostic_bbox_pred': class_agnostic_bbox_pred,
'cascade_class_ensemble': cascade_class_ensemble,
}
self.backbone = backbone
self.decoder = decoder
self.rpn_head = rpn_head
self.detection_head = detection_head
self.roi_generator = roi_generator
if not isinstance(roi_sampler, (list, tuple)):
self.roi_sampler = [roi_sampler]
else:
self.roi_sampler = roi_sampler
if len(self.roi_sampler) > 1 and not class_agnostic_bbox_pred:
raise ValueError(
'`class_agnostic_bbox_pred` needs to be True if multiple detection heads are specified.'
)
self.roi_aligner = roi_aligner
self.detection_generator = detection_generator
self._include_mask = mask_head is not None
@@ -85,6 +102,13 @@ class MaskRCNNModel(tf.keras.Model):
if self._include_mask and mask_roi_aligner is None:
raise ValueError('`mask_roi_aligner` is not provided in Mask R-CNN.')
self.mask_roi_aligner = mask_roi_aligner
# Weights for the regression losses for each FRCNN layer.
# TODO(xianzhi): Make the weights configurable.
self._cascade_layer_to_weights = [
[10.0, 10.0, 5.0, 5.0],
[20.0, 20.0, 10.0, 10.0],
[30.0, 30.0, 15.0, 15.0],
]
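Larger weights divide the predicted deltas more aggressively before decoding, so later cascade heads, which refine already-good boxes, apply smaller corrections. A minimal numpy sketch of the standard Faster R-CNN decoding convention with such weights (my own illustration, not the `box_ops` implementation):

import numpy as np

def decode_box(anchor, deltas, weights):
  # anchor and result are [ymin, xmin, ymax, xmax]; deltas are [dy, dx, dh, dw].
  ymin, xmin, ymax, xmax = anchor
  h, w = ymax - ymin, xmax - xmin
  cy, cx = ymin + 0.5 * h, xmin + 0.5 * w
  dy, dx, dh, dw = [d / wt for d, wt in zip(deltas, weights)]
  cy, cx = cy + dy * h, cx + dx * w
  h, w = h * np.exp(dh), w * np.exp(dw)
  return [cy - 0.5 * h, cx - 0.5 * w, cy + 0.5 * h, cx + 0.5 * w]

anchor, deltas = [0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 0.2, 0.2]
print(decode_box(anchor, deltas, [10.0, 10.0, 5.0, 5.0]))    # first head
print(decode_box(anchor, deltas, [20.0, 20.0, 10.0, 10.0]))  # later head: half the shift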
def call(self,
images: tf.Tensor,
@@ -110,44 +134,50 @@ class MaskRCNNModel(tf.keras.Model):
})
# Generate RoIs.
current_rois, _ = self.roi_generator(rpn_boxes, rpn_scores, anchor_boxes,
image_shape, training)
next_rois = current_rois
all_class_outputs = []
for cascade_num in range(len(self.roi_sampler)):
# In cascade RCNN we want the higher layers to have different regression
# weights as the predicted deltas become smaller and smaller.
regression_weights = self._cascade_layer_to_weights[cascade_num]
current_rois = next_rois
(class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices,
current_rois) = self._run_frcnn_head(
features=features,
rois=current_rois,
gt_boxes=gt_boxes,
gt_classes=gt_classes,
training=training,
model_outputs=model_outputs,
layer_num=cascade_num,
regression_weights=regression_weights)
all_class_outputs.append(class_outputs)
# Generate ROIs for the next cascade head if there is any.
if cascade_num < len(self.roi_sampler) - 1:
next_rois = box_ops.decode_boxes(
tf.cast(box_outputs, tf.float32),
current_rois,
weights=regression_weights)
next_rois = box_ops.clip_boxes(next_rois,
tf.expand_dims(image_shape, axis=1))
if not training:
if self._config_dict['cascade_class_ensemble']:
class_outputs = tf.add_n(all_class_outputs) / len(all_class_outputs)
detections = self.detection_generator(
box_outputs,
class_outputs,
current_rois,
image_shape,
regression_weights,
bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred']))
model_outputs.update({
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
@@ -159,12 +189,9 @@ class MaskRCNNModel(tf.keras.Model):
return model_outputs
if training:
current_rois, roi_classes, roi_masks = self.mask_sampler(
current_rois, matched_gt_boxes, matched_gt_classes,
matched_gt_indices, gt_masks)
roi_masks = tf.stop_gradient(roi_masks)
model_outputs.update({
@@ -172,11 +199,11 @@ class MaskRCNNModel(tf.keras.Model):
'mask_targets': roi_masks,
})
else:
current_rois = model_outputs['detection_boxes']
roi_classes = model_outputs['detection_classes']
# Mask RoI align.
mask_roi_features = self.mask_roi_aligner(features, current_rois)
# Mask head.
raw_masks = self.mask_head([mask_roi_features, roi_classes])
@@ -191,6 +218,80 @@ class MaskRCNNModel(tf.keras.Model):
})
return model_outputs
def _run_frcnn_head(self, features, rois, gt_boxes, gt_classes, training,
model_outputs, layer_num, regression_weights):
"""Runs the frcnn head that does both class and box prediction.
Args:
features: `list` of features from the feature extractor.
rois: `list` of current rois that will be used to predict bbox refinement
and classes from.
gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
This tensor might have paddings with a negative value.
gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
classes. It is padded with -1s to indicate the invalid classes.
training: `bool`, if model is training or being evaluated.
model_outputs: `dict`, used for storing outputs used for eval and losses.
layer_num: `int`, the current frcnn layer in the cascade.
regression_weights: `list`, weights used for l1 loss in bounding box
regression.
Returns:
class_outputs: Class predictions for rois.
box_outputs: Box predictions for rois. These are formatted for the
regression loss and need to be converted before being used as rois
in the next stage.
model_outputs: Updated dict with predictions used for losses and eval.
matched_gt_boxes: If `is_training` is true, then these give the gt box
location of its positive match.
matched_gt_classes: If `is_training` is true, then these give the gt class
of the predicted box.
matched_gt_indices: If `is_training` is true, then gives the index of
the positive box match. Used for mask prediction.
rois: The sampled rois used for this layer.
"""
# Only used during training.
matched_gt_boxes, matched_gt_classes, matched_gt_indices = (None, None,
None)
if training:
rois = tf.stop_gradient(rois)
current_roi_sampler = self.roi_sampler[layer_num]
rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
current_roi_sampler(rois, gt_boxes, gt_classes))
# Create bounding box training targets.
box_targets = box_ops.encode_boxes(
matched_gt_boxes, rois, weights=regression_weights)
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]), tf.zeros_like(box_targets), box_targets)
model_outputs.update({
'class_targets_{}'.format(layer_num)
if layer_num else 'class_targets':
matched_gt_classes,
'box_targets_{}'.format(layer_num) if layer_num else 'box_targets':
box_targets,
})
# Get roi features.
roi_features = self.roi_aligner(features, rois)
# Run frcnn head to get class and bbox predictions.
class_outputs, box_outputs = self.detection_head(roi_features)
model_outputs.update({
'class_outputs_{}'.format(layer_num) if layer_num else 'class_outputs':
class_outputs,
'box_outputs_{}'.format(layer_num) if layer_num else 'box_outputs':
box_outputs,
})
return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices, rois)
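The per-head outputs written above follow a simple naming convention: the first head keeps the original keys so single-head models are unaffected, and later heads get an index suffix. A tiny sketch of that convention (illustrative helper, not part of the change):

def output_key(base, layer_num):
  return '{}_{}'.format(base, layer_num) if layer_num else base

print([output_key('class_outputs', i) for i in range(3)])
# ['class_outputs', 'class_outputs_1', 'class_outputs_2']
print([output_key('box_targets', i) for i in range(3)])
# ['box_targets', 'box_targets_1', 'box_targets_2']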
@property
def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
@@ -125,17 +125,30 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(expected_num_params, model.count_params())
@parameterized.parameters(
(False, False, False),
(False, True, False),
(False, False, True),
(False, True, True),
(True, False, False),
(True, True, False),
(True, False, True),
(True, True, True),
)
def test_forward(self, include_mask, training, use_cascade_heads):
num_classes = 3
min_level = 3
max_level = 4
num_scales = 3
aspect_ratios = [1.0]
if use_cascade_heads:
cascade_iou_thresholds = [0.6]
class_agnostic_bbox_pred = True
cascade_class_ensemble = True
else:
cascade_iou_thresholds = None
class_agnostic_bbox_pred = False
cascade_class_ensemble = False
image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3)
image_shape = np.array([[224, 100], [100, 224]])
@@ -159,9 +172,23 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
max_level=max_level,
num_anchors_per_location=num_anchors_per_location)
detection_head = instance_heads.DetectionHead(
num_classes=num_classes,
class_agnostic_bbox_pred=class_agnostic_bbox_pred)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_cascade = []
roi_sampler_obj = roi_sampler.ROISampler()
roi_sampler_cascade.append(roi_sampler_obj)
if cascade_iou_thresholds:
for iou in cascade_iou_thresholds:
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=False,
foreground_iou_threshold=iou,
background_iou_high_threshold=iou,
background_iou_low_threshold=0.0,
skip_subsampling=True)
roi_sampler_cascade.append(roi_sampler_obj)
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
if include_mask:
@@ -180,12 +207,14 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_cascade,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
@@ -161,6 +161,7 @@ class MaskRCNNTask(base_task.Task):
aux_losses: Optional[Any] = None):
"""Build Mask R-CNN losses."""
params = self.task_config
cascade_ious = params.model.roi_sampler.cascade_iou_thresholds
rpn_score_loss_fn = maskrcnn_losses.RpnScoreLoss(
tf.shape(outputs['box_outputs'])[1])
@@ -175,15 +176,32 @@ class MaskRCNNTask(base_task.Task):
frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss()
frcnn_box_loss_fn = maskrcnn_losses.FastrcnnBoxLoss(
params.losses.frcnn_huber_loss_delta,
params.model.detection_head.class_agnostic_bbox_pred)
# Final cls/box losses are computed as an average of all detection heads.
frcnn_cls_loss = 0.0
frcnn_box_loss = 0.0
num_det_heads = 1 if cascade_ious is None else 1 + len(cascade_ious)
for cas_num in range(num_det_heads):
frcnn_cls_loss_i = tf.reduce_mean(
frcnn_cls_loss_fn(
outputs['class_outputs_{}'
.format(cas_num) if cas_num else 'class_outputs'],
outputs['class_targets_{}'
.format(cas_num) if cas_num else 'class_targets']))
frcnn_box_loss_i = tf.reduce_mean(
frcnn_box_loss_fn(
outputs['box_outputs_{}'.format(cas_num
) if cas_num else 'box_outputs'],
outputs['class_targets_{}'
.format(cas_num) if cas_num else 'class_targets'],
outputs['box_targets_{}'.format(cas_num
) if cas_num else 'box_targets']))
frcnn_cls_loss += frcnn_cls_loss_i
frcnn_box_loss += frcnn_box_loss_i
frcnn_cls_loss /= num_det_heads
frcnn_box_loss /= num_det_heads
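The number of detection heads is derived from the sampler config, and the final Fast R-CNN losses are the mean over heads; a small sketch of that bookkeeping (loss values are placeholders):

for cascade_ious in (None, [0.6], [0.6, 0.7]):
  num_det_heads = 1 if cascade_ious is None else 1 + len(cascade_ious)
  print(cascade_ious, num_det_heads)
# None 1
# [0.6] 2
# [0.6, 0.7] 3

per_head_cls_losses = [0.9, 0.6, 0.4]  # placeholder per-head losses
frcnn_cls_loss = sum(per_head_cls_losses) / len(per_head_cls_losses)
print(round(frcnn_cls_loss, 4))  # 0.6333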
if params.model.include_mask:
mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()