Commit d1a3fa6a authored by Xianzhi Du, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 370677512
parent 834c0454
......@@ -100,6 +100,10 @@ class DetectionHead(hyperparams.Config):
use_separable_conv: bool = False
num_fcs: int = 1
fc_dims: int = 1024
class_agnostic_bbox_pred: bool = False # Has to be True for Cascade RCNN.
# If additional IoUs are passed in 'cascade_iou_thresholds'
# then ensemble the class probabilities from all heads.
cascade_class_ensemble: bool = False
@dataclasses.dataclass
......@@ -125,6 +129,9 @@ class ROISampler(hyperparams.Config):
foreground_iou_threshold: float = 0.5
background_iou_high_threshold: float = 0.5
background_iou_low_threshold: float = 0.0
# IoU thresholds for additional FRCNN heads in Cascade mode.
# `foreground_iou_threshold` is the first threshold.
cascade_iou_thresholds: Optional[List[float]] = None
@dataclasses.dataclass
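Note: taken together, these new fields describe the full cascade. A minimal sketch of how they combine for a three-stage cascade, using the same values as the `cascadercnn_resnetfpn_coco` experiment further down (assumes these config dataclasses are importable from this configs module):

detection_head = DetectionHead(
    class_agnostic_bbox_pred=True,   # has to be True for Cascade RCNN
    cascade_class_ensemble=True)     # average class scores over all heads
roi_sampler = ROISampler(
    foreground_iou_threshold=0.5,        # first (standard) FRCNN head
    cascade_iou_thresholds=[0.6, 0.7])   # two additional cascade heads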
......@@ -282,7 +289,6 @@ def fasterrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
......@@ -292,6 +298,68 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
steps_per_epoch = 500
coco_val_samples = 5000
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=MaskRCNNTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=MaskRCNN(
num_classes=91, input_size=[1024, 1024, 3], include_mask=True),
losses=Losses(l2_weight_decay=0.00004),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=64,
parser=Parser(
aug_rand_hflip=True, aug_scale_min=0.8, aug_scale_max=1.25)),
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=8)),
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // 8,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [15000, 20000],
'values': [0.12, 0.012, 0.0012],
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 500,
'warmup_learning_rate': 0.0067
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('cascadercnn_resnetfpn_coco')
def cascadercnn_resnetfpn_coco() -> cfg.ExperimentConfig:
"""COCO object detection with Cascade R-CNN."""
steps_per_epoch = 500
coco_val_samples = 5000
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=MaskRCNNTask(
......@@ -302,7 +370,10 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
model=MaskRCNN(
num_classes=91,
input_size=[1024, 1024, 3],
include_mask=True),
include_mask=True,
roi_sampler=ROISampler(cascade_iou_thresholds=[0.6, 0.7]),
detection_head=DetectionHead(
class_agnostic_bbox_pred=True, cascade_class_ensemble=True)),
losses=Losses(l2_weight_decay=0.00004),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
......@@ -347,7 +418,6 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
......
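Note: once registered with `exp_factory.register_config_factory`, the cascade experiment can be pulled up by name. A quick sketch (the `official.core` import path is assumed from the model garden layout):

from official.core import exp_factory

config = exp_factory.get_exp_config('cascadercnn_resnetfpn_coco')
assert config.task.model.detection_head.class_agnostic_bbox_pred
assert config.task.model.roi_sampler.cascade_iou_thresholds == [0.6, 0.7]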
......@@ -173,13 +173,21 @@ class FastrcnnClassLoss(object):
class FastrcnnBoxLoss(object):
"""Fast R-CNN box regression loss function."""
def __init__(self, huber_loss_delta: float):
# The delta is typically around the mean value of regression target.
# for instances, the regression targets of 512x512 input with 6 anchors on
# P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
def __init__(self,
huber_loss_delta: float,
class_agnostic_bbox_pred: bool = False):
"""Initiate Faster RCNN box loss.
Args:
huber_loss_delta: the delta is typically around the mean value of the
regression targets. For instance, the regression targets of a 512x512
input with 6 anchors on the P2-P6 pyramid are about [0.1, 0.1, 0.2, 0.2].
class_agnostic_bbox_pred: if True, class agnostic bounding box prediction
is performed.
"""
self._huber_loss = tf.keras.losses.Huber(
delta=huber_loss_delta, reduction=tf.keras.losses.Reduction.SUM)
self._class_agnostic_bbox_pred = class_agnostic_bbox_pred
def __call__(self, box_outputs, class_targets, box_targets):
"""Computes the box loss (Fast-RCNN branch) of Mask-RCNN.
......@@ -207,9 +215,13 @@ class FastrcnnBoxLoss(object):
"""
with tf.name_scope('fast_rcnn_loss'):
class_targets = tf.cast(class_targets, dtype=tf.int32)
if not self._class_agnostic_bbox_pred:
box_outputs = self._assign_class_targets(box_outputs, class_targets)
return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)
# Selects the box from `box_outputs` based on `class_targets`, with which
# the box has the maximum overlap.
def _assign_class_targets(self, box_outputs, class_targets):
"""Selects the box from `box_outputs` based on `class_targets`, with which the box has the maximum overlap."""
(batch_size, num_rois,
num_class_specific_boxes) = box_outputs.get_shape().as_list()
num_classes = num_class_specific_boxes // 4
......@@ -218,8 +230,7 @@ class FastrcnnBoxLoss(object):
box_indices = tf.reshape(
class_targets + tf.tile(
tf.expand_dims(
tf.range(batch_size) * num_rois * num_classes, 1),
tf.expand_dims(tf.range(batch_size) * num_rois * num_classes, 1),
[1, num_rois]) + tf.tile(
tf.expand_dims(tf.range(num_rois) * num_classes, 0),
[batch_size, 1]), [-1])
......@@ -231,7 +242,7 @@ class FastrcnnBoxLoss(object):
dtype=box_outputs.dtype), tf.reshape(box_outputs, [-1, 4]))
box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])
return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)
return box_outputs
def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
normalizer=1.0):
......@@ -299,4 +310,3 @@ class MaskrcnnLoss(object):
# The loss is normalized by the number of 1's in weights and
# + 0.01 is used to avoid division by zero.
return mask_loss / (tf.reduce_sum(weights) + 0.01)
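Note: the per-class box selection done by `_assign_class_targets` above can be pictured with a toy example. The sketch below uses `tf.gather` with `batch_dims` for brevity; the actual implementation builds flat indices and gathers from a reshaped tensor:

import tensorflow as tf

batch_size, num_rois, num_classes = 2, 3, 5
box_outputs = tf.random.normal([batch_size, num_rois, num_classes * 4])
class_targets = tf.random.uniform(
    [batch_size, num_rois], maxval=num_classes, dtype=tf.int32)

# Pick, for every RoI, the 4 box coordinates of its target class.
per_class = tf.reshape(box_outputs, [batch_size, num_rois, num_classes, 4])
selected = tf.gather(per_class, class_targets, batch_dims=2)  # [2, 3, 4]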
......@@ -109,6 +109,7 @@ def build_maskrcnn(
use_separable_conv=detection_head_config.use_separable_conv,
num_fcs=detection_head_config.num_fcs,
fc_dims=detection_head_config.fc_dims,
class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
......@@ -131,6 +132,7 @@ def build_maskrcnn(
test_num_proposals=roi_generator_config.test_num_proposals,
use_batched_nms=roi_generator_config.use_batched_nms)
roi_sampler_cascade = []
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
num_sampled_rois=roi_sampler_config.num_sampled_rois,
......@@ -140,6 +142,18 @@ def build_maskrcnn(
roi_sampler_config.background_iou_high_threshold),
background_iou_low_threshold=(
roi_sampler_config.background_iou_low_threshold))
roi_sampler_cascade.append(roi_sampler_obj)
# Initialize additional roi samplers for cascade heads.
if roi_sampler_config.cascade_iou_thresholds:
for iou in roi_sampler_config.cascade_iou_thresholds:
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=False,
num_sampled_rois=roi_sampler_config.num_sampled_rois,
foreground_iou_threshold=iou,
background_iou_high_threshold=iou,
background_iou_low_threshold=0.0,
skip_subsampling=True)
roi_sampler_cascade.append(roi_sampler_obj)
roi_aligner_obj = roi_aligner.MultilevelROIAligner(
crop_size=roi_aligner_config.crop_size,
......@@ -186,12 +200,14 @@ def build_maskrcnn(
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator_obj,
roi_sampler=roi_sampler_obj,
roi_sampler=roi_sampler_cascade,
roi_aligner=roi_aligner_obj,
detection_generator=detection_generator_obj,
mask_head=mask_head,
mask_sampler=mask_sampler_obj,
mask_roi_aligner=mask_roi_aligner_obj)
mask_roi_aligner=mask_roi_aligner_obj,
class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
cascade_class_ensemble=detection_head_config.cascade_class_ensemble)
return model
......
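Note: with the COCO Cascade R-CNN config above (`cascade_iou_thresholds=[0.6, 0.7]`), this loop produces three samplers in total, one per detection stage at IoU 0.5, 0.6 and 0.7. A trivial sanity check under that assumption:

foreground_iou_threshold = 0.5
cascade_iou_thresholds = [0.6, 0.7]
stage_ious = [foreground_iou_threshold] + (cascade_iou_thresholds or [])
assert stage_ious == [0.5, 0.6, 0.7]  # one ROISampler per stage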
......@@ -33,6 +33,7 @@ class DetectionHead(tf.keras.layers.Layer):
use_separable_conv: bool = False,
num_fcs: int = 2,
fc_dims: int = 1024,
class_agnostic_bbox_pred: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
......@@ -54,6 +55,8 @@ class DetectionHead(tf.keras.layers.Layer):
the predictions.
fc_dims: An `int` number that represents the number of dimension of the FC
layers.
class_agnostic_bbox_pred: `bool`, indicating whether a single class-agnostic
box (rather than one box per class) is predicted for each RoI.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
......@@ -73,6 +76,7 @@ class DetectionHead(tf.keras.layers.Layer):
'use_separable_conv': use_separable_conv,
'num_fcs': num_fcs,
'fc_dims': fc_dims,
'class_agnostic_bbox_pred': class_agnostic_bbox_pred,
'activation': activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
......@@ -155,8 +159,11 @@ class DetectionHead(tf.keras.layers.Layer):
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'],
name='detection-scores')
num_box_outputs = (4 if self._config_dict['class_agnostic_bbox_pred'] else
self._config_dict['num_classes'] * 4)
self._box_regressor = tf.keras.layers.Dense(
units=self._config_dict['num_classes'] * 4,
units=num_box_outputs,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
bias_initializer=tf.zeros_initializer(),
kernel_regularizer=self._config_dict['kernel_regularizer'],
......
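Note: the only structural effect of `class_agnostic_bbox_pred` on the head is the width of the box regressor. A toy shape check (assumes 1024-d FC features, matching the default `fc_dims`):

import tensorflow as tf

num_classes = 91
features = tf.zeros([2, 1024])
per_class_regressor = tf.keras.layers.Dense(num_classes * 4)
agnostic_regressor = tf.keras.layers.Dense(4)
print(per_class_regressor(features).shape)  # (2, 364)
print(agnostic_regressor(features).shape)   # (2, 4)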
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains definitions of generators to generate the final detections."""
from typing import Optional, Mapping
from typing import List, Optional, Mapping
# Import libraries
import tensorflow as tf
......@@ -432,8 +432,13 @@ class DetectionGenerator(tf.keras.layers.Layer):
}
super(DetectionGenerator, self).__init__(**kwargs)
def __call__(self, raw_boxes: tf.Tensor, raw_scores: tf.Tensor,
anchor_boxes: tf.Tensor, image_shape: tf.Tensor):
def __call__(self,
raw_boxes: tf.Tensor,
raw_scores: tf.Tensor,
anchor_boxes: tf.Tensor,
image_shape: tf.Tensor,
regression_weights: Optional[List[float]] = None,
bbox_per_class: bool = True):
"""Generates final detections.
Args:
......@@ -446,6 +451,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
image_shape: A `tf.Tensor` of shape of `[batch_size, 2]` storing the image
height and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`.
regression_weights: A list of four float numbers to scale coordinates.
bbox_per_class: A `bool`. If True, perform per-class box regression.
Returns:
If `apply_nms` = True, the return is a dictionary with keys:
......@@ -473,9 +480,11 @@ class DetectionGenerator(tf.keras.layers.Layer):
batch_size = box_scores_shape[0]
num_locations = box_scores_shape_list[1]
num_classes = box_scores_shape_list[-1]
num_detections = num_locations * (num_classes - 1)
box_scores = tf.slice(box_scores, [0, 0, 1], [-1, -1, -1])
if bbox_per_class:
num_detections = num_locations * (num_classes - 1)
raw_boxes = tf.reshape(raw_boxes,
[batch_size, num_locations, num_classes, 4])
raw_boxes = tf.slice(raw_boxes, [0, 0, 1, 0], [-1, -1, -1, -1])
......@@ -486,14 +495,17 @@ class DetectionGenerator(tf.keras.layers.Layer):
# Box decoding.
decoded_boxes = box_ops.decode_boxes(
raw_boxes, anchor_boxes, weights=[10.0, 10.0, 5.0, 5.0])
raw_boxes, anchor_boxes, weights=regression_weights)
# Box clipping
decoded_boxes = box_ops.clip_boxes(
decoded_boxes, tf.expand_dims(image_shape, axis=1))
decoded_boxes = tf.reshape(decoded_boxes,
[batch_size, num_locations, num_classes - 1, 4])
if bbox_per_class:
decoded_boxes = tf.reshape(
decoded_boxes, [batch_size, num_locations, num_classes - 1, 4])
else:
decoded_boxes = tf.expand_dims(decoded_boxes, axis=2)
if not self._config_dict['apply_nms']:
return {
......
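Note: `regression_weights` only rescales the raw deltas before the standard decoding. A self-contained sketch of that decoding (roughly what `box_ops.decode_boxes` does; the exact clamping and dtype handling in `box_ops` may differ):

import tensorflow as tf

def decode_boxes_sketch(deltas, anchors, weights):
  # deltas: [..., 4] as (dy, dx, dh, dw); anchors: [..., 4] as
  # (ymin, xmin, ymax, xmax); weights: 4 floats, e.g. [10.0, 10.0, 5.0, 5.0].
  dy, dx, dh, dw = tf.unstack(
      deltas / tf.constant(weights, dtype=deltas.dtype), axis=-1)
  ymin, xmin, ymax, xmax = tf.unstack(anchors, axis=-1)
  ha, wa = ymax - ymin, xmax - xmin
  cya, cxa = ymin + 0.5 * ha, xmin + 0.5 * wa
  cy, cx = dy * ha + cya, dx * wa + cxa
  h, w = tf.exp(dh) * ha, tf.exp(dw) * wa
  return tf.stack(
      [cy - 0.5 * h, cx - 0.5 * w, cy + 0.5 * h, cx + 0.5 * w], axis=-1)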
......@@ -31,6 +31,7 @@ class ROISampler(tf.keras.layers.Layer):
foreground_iou_threshold: float = 0.5,
background_iou_high_threshold: float = 0.5,
background_iou_low_threshold: float = 0,
skip_subsampling: bool = False,
**kwargs):
"""Initializes a ROI sampler.
......@@ -48,6 +49,9 @@ class ROISampler(tf.keras.layers.Layer):
background_iou_low_threshold: A `float` that represents the IoU threshold
for a box to be considered as negative (if overlap in
[`background_iou_low_threshold`, `background_iou_high_threshold`])
skip_subsampling: a bool that determines whether to skip the sampling
procedure that balances the fg/bg classes. Used for upper frcnn layers
in cascade RCNN.
**kwargs: Additional keyword arguments passed to Layer.
"""
self._config_dict = {
......@@ -57,6 +61,7 @@ class ROISampler(tf.keras.layers.Layer):
'foreground_iou_threshold': foreground_iou_threshold,
'background_iou_high_threshold': background_iou_high_threshold,
'background_iou_low_threshold': background_iou_low_threshold,
'skip_subsampling': skip_subsampling,
}
self._sim_calc = keras_cv.ops.IouSimilarity()
......@@ -110,8 +115,8 @@ class ROISampler(tf.keras.layers.Layer):
tensor, i.e.,
gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
"""
if self._config_dict['mix_gt_boxes']:
gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
if self._config_dict['mix_gt_boxes']:
boxes = tf.concat([boxes, gt_boxes], axis=1)
boxes_invalid_mask = tf.less(
......@@ -143,6 +148,10 @@ class ROISampler(tf.keras.layers.Layer):
tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
matched_gt_indices)
if self._config_dict['skip_subsampling']:
return (boxes, matched_gt_boxes, tf.squeeze(matched_gt_classes,
axis=-1), matched_gt_indices)
sampled_indices = self._sampler(
positive_matches, negative_matches, ignored_matches)
......
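Note: raising the foreground IoU threshold at each cascade stage is what makes later stages train on progressively better-aligned proposals. A toy illustration of the matching step (the real sampler computes IoU with `keras_cv.ops.IouSimilarity`, also handles the background/ignore ranges, and only subsamples when `skip_subsampling` is False):

import tensorflow as tf

# Toy IoU matrix: 4 proposals x 2 ground-truth boxes.
iou = tf.constant([[0.72, 0.10],
                   [0.55, 0.05],
                   [0.30, 0.65],
                   [0.02, 0.01]])
for threshold in (0.5, 0.6, 0.7):
  is_foreground = tf.reduce_max(iou, axis=-1) >= threshold
  print(threshold, is_foreground.numpy())  # fewer positives as IoU rises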
......@@ -14,7 +14,7 @@
"""Mask R-CNN model."""
from typing import Any, Mapping, Optional, Union
from typing import Any, List, Mapping, Optional, Union
# Import libraries
import tensorflow as tf
......@@ -32,12 +32,15 @@ class MaskRCNNModel(tf.keras.Model):
rpn_head: tf.keras.layers.Layer,
detection_head: tf.keras.layers.Layer,
roi_generator: tf.keras.layers.Layer,
roi_sampler: tf.keras.layers.Layer,
roi_sampler: Union[tf.keras.layers.Layer,
List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer,
mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
class_agnostic_bbox_pred: bool = False,
cascade_class_ensemble: bool = False,
**kwargs):
"""Initializes the Mask R-CNN model.
......@@ -47,12 +50,17 @@ class MaskRCNNModel(tf.keras.Model):
rpn_head: the RPN head.
detection_head: the detection head.
roi_generator: the ROI generator.
roi_sampler: the ROI sampler.
roi_sampler: a single ROI sampler or a list of ROI samplers for cascade
detection heads.
roi_aligner: the ROI aligner.
detection_generator: the detection generator.
mask_head: the mask head.
mask_sampler: the mask sampler.
mask_roi_aligner: the ROI aligner for mask prediction.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
cascade_class_ensemble: if True, ensemble classification scores over
all detection heads.
**kwargs: keyword arguments to be passed.
"""
super(MaskRCNNModel, self).__init__(**kwargs)
......@@ -68,13 +76,22 @@ class MaskRCNNModel(tf.keras.Model):
'mask_head': mask_head,
'mask_sampler': mask_sampler,
'mask_roi_aligner': mask_roi_aligner,
'class_agnostic_bbox_pred': class_agnostic_bbox_pred,
'cascade_class_ensemble': cascade_class_ensemble,
}
self.backbone = backbone
self.decoder = decoder
self.rpn_head = rpn_head
self.detection_head = detection_head
self.roi_generator = roi_generator
if not isinstance(roi_sampler, (list, tuple)):
self.roi_sampler = [roi_sampler]
else:
self.roi_sampler = roi_sampler
if len(self.roi_sampler) > 1 and not class_agnostic_bbox_pred:
raise ValueError(
'`class_agnostic_bbox_pred` needs to be True if multiple detection heads are specified.'
)
self.roi_aligner = roi_aligner
self.detection_generator = detection_generator
self._include_mask = mask_head is not None
......@@ -85,6 +102,13 @@ class MaskRCNNModel(tf.keras.Model):
if self._include_mask and mask_roi_aligner is None:
raise ValueError('`mask_roi_aligner` is not provided in Mask R-CNN.')
self.mask_roi_aligner = mask_roi_aligner
# Weights for the regression losses for each FRCNN layer.
# TODO(xianzhi): Make the weights configurable.
self._cascade_layer_to_weights = [
[10.0, 10.0, 5.0, 5.0],
[20.0, 20.0, 10.0, 10.0],
[30.0, 30.0, 15.0, 15.0],
]
def call(self,
images: tf.Tensor,
......@@ -110,44 +134,50 @@ class MaskRCNNModel(tf.keras.Model):
})
# Generate RoIs.
rois, _ = self.roi_generator(
rpn_boxes, rpn_scores, anchor_boxes, image_shape, training)
current_rois, _ = self.roi_generator(rpn_boxes, rpn_scores, anchor_boxes,
image_shape, training)
if training:
rois = tf.stop_gradient(rois)
next_rois = current_rois
all_class_outputs = []
for cascade_num in range(len(self.roi_sampler)):
# In cascade RCNN we want the higher layers to have different regression
# weights as the predicted deltas become smaller and smaller.
regression_weights = self._cascade_layer_to_weights[cascade_num]
current_rois = next_rois
rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
self.roi_sampler(rois, gt_boxes, gt_classes))
# Assign target for the 2nd stage classification.
box_targets = box_ops.encode_boxes(
matched_gt_boxes, rois, weights=[10.0, 10.0, 5.0, 5.0])
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]),
tf.zeros_like(box_targets),
box_targets)
model_outputs.update({
'class_targets': matched_gt_classes,
'box_targets': box_targets,
})
(class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices,
current_rois) = self._run_frcnn_head(
features=features,
rois=current_rois,
gt_boxes=gt_boxes,
gt_classes=gt_classes,
training=training,
model_outputs=model_outputs,
layer_num=cascade_num,
regression_weights=regression_weights)
all_class_outputs.append(class_outputs)
# RoI align.
roi_features = self.roi_aligner(features, rois)
# Generate ROIs for the next cascade head if there is any.
if cascade_num < len(self.roi_sampler) - 1:
next_rois = box_ops.decode_boxes(
tf.cast(box_outputs, tf.float32),
current_rois,
weights=regression_weights)
next_rois = box_ops.clip_boxes(next_rois,
tf.expand_dims(image_shape, axis=1))
# Detection head.
raw_scores, raw_boxes = self.detection_head(roi_features)
if not training:
if self._config_dict['cascade_class_ensemble']:
class_outputs = tf.add_n(all_class_outputs) / len(all_class_outputs)
if training:
model_outputs.update({
'class_outputs': raw_scores,
'box_outputs': raw_boxes,
})
else:
# Post-processing.
detections = self.detection_generator(
raw_boxes, raw_scores, rois, image_shape)
box_outputs,
class_outputs,
current_rois,
image_shape,
regression_weights,
bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred']))
model_outputs.update({
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
......@@ -159,12 +189,9 @@ class MaskRCNNModel(tf.keras.Model):
return model_outputs
if training:
rois, roi_classes, roi_masks = self.mask_sampler(
rois,
matched_gt_boxes,
matched_gt_classes,
matched_gt_indices,
gt_masks)
current_rois, roi_classes, roi_masks = self.mask_sampler(
current_rois, matched_gt_boxes, matched_gt_classes,
matched_gt_indices, gt_masks)
roi_masks = tf.stop_gradient(roi_masks)
model_outputs.update({
......@@ -172,11 +199,11 @@ class MaskRCNNModel(tf.keras.Model):
'mask_targets': roi_masks,
})
else:
rois = model_outputs['detection_boxes']
current_rois = model_outputs['detection_boxes']
roi_classes = model_outputs['detection_classes']
# Mask RoI align.
mask_roi_features = self.mask_roi_aligner(features, rois)
mask_roi_features = self.mask_roi_aligner(features, current_rois)
# Mask head.
raw_masks = self.mask_head([mask_roi_features, roi_classes])
......@@ -191,6 +218,80 @@ class MaskRCNNModel(tf.keras.Model):
})
return model_outputs
def _run_frcnn_head(self, features, rois, gt_boxes, gt_classes, training,
model_outputs, layer_num, regression_weights):
"""Runs the frcnn head that does both class and box prediction.
Args:
features: `list` of features from the feature extractor.
rois: `list` of current rois from which bbox refinements and classes
will be predicted.
gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
This tensor might have paddings with a negative value.
gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
classes. It is padded with -1s to indicate the invalid classes.
training: `bool`, if model is training or being evaluated.
model_outputs: `dict`, used for storing outputs used for eval and losses.
layer_num: `int`, the current frcnn layer in the cascade.
regression_weights: `list`, weights used for l1 loss in bounding box
regression.
Returns:
class_outputs: Class predictions for rois.
box_outputs: Box predictions for rois. These are formatted for the
regression loss and need to be converted before being used as rois
in the next stage.
model_outputs: Updated dict with predictions used for losses and eval.
matched_gt_boxes: If `is_training` is true, then these give the gt box
location of its positive match.
matched_gt_classes: If `is_training` is true, then these give the gt class
of the predicted box.
matched_gt_indices: If `is_training` is true, then gives the index of
the positive box match. Used for mask prediction.
rois: The sampled rois used for this layer.
"""
# Only used during training.
matched_gt_boxes, matched_gt_classes, matched_gt_indices = (None, None,
None)
if training:
rois = tf.stop_gradient(rois)
current_roi_sampler = self.roi_sampler[layer_num]
rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
current_roi_sampler(rois, gt_boxes, gt_classes))
# Create bounding box training targets.
box_targets = box_ops.encode_boxes(
matched_gt_boxes, rois, weights=regression_weights)
# If the target is background, the box target is set to all 0s.
box_targets = tf.where(
tf.tile(
tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
[1, 1, 4]), tf.zeros_like(box_targets), box_targets)
model_outputs.update({
'class_targets_{}'.format(layer_num)
if layer_num else 'class_targets':
matched_gt_classes,
'box_targets_{}'.format(layer_num) if layer_num else 'box_targets':
box_targets,
})
# Get roi features.
roi_features = self.roi_aligner(features, rois)
# Run frcnn head to get class and bbox predictions.
class_outputs, box_outputs = self.detection_head(roi_features)
model_outputs.update({
'class_outputs_{}'.format(layer_num) if layer_num else 'class_outputs':
class_outputs,
'box_outputs_{}'.format(layer_num) if layer_num else 'box_outputs':
box_outputs,
})
return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
matched_gt_classes, matched_gt_indices, rois)
@property
def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
......
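Note: two cascade-specific details in `call`/`_run_frcnn_head` are easy to miss: every stage uses progressively larger regression weights (the deltas it has to predict shrink stage by stage), and with `cascade_class_ensemble` the class scores are averaged over all heads at inference. A toy sketch of both, with placeholder outputs standing in for the real detection head:

import tensorflow as tf

cascade_layer_to_weights = [
    [10.0, 10.0, 5.0, 5.0],
    [20.0, 20.0, 10.0, 10.0],
    [30.0, 30.0, 15.0, 15.0],
]  # same table as MaskRCNNModel._cascade_layer_to_weights

all_class_outputs = []
for cascade_num in range(3):
  regression_weights = cascade_layer_to_weights[cascade_num]
  # ... sample RoIs, RoI-align and run the detection head for this stage ...
  class_outputs = tf.random.normal([2, 8, 91])  # placeholder head output
  all_class_outputs.append(class_outputs)

# cascade_class_ensemble: average class scores over all detection heads.
ensembled = tf.add_n(all_class_outputs) / len(all_class_outputs)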
......@@ -125,17 +125,30 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(expected_num_params, model.count_params())
@parameterized.parameters(
(False, False,),
(False, True,),
(True, False,),
(True, True,),
(False, False, False),
(False, True, False),
(False, False, True),
(False, True, True),
(True, False, False),
(True, True, False),
(True, False, True),
(True, True, True),
)
def test_forward(self, include_mask, training):
def test_forward(self, include_mask, training, use_cascade_heads):
num_classes = 3
min_level = 3
max_level = 4
num_scales = 3
aspect_ratios = [1.0]
if use_cascade_heads:
cascade_iou_thresholds = [0.6]
class_agnostic_bbox_pred = True
cascade_class_ensemble = True
else:
cascade_iou_thresholds = None
class_agnostic_bbox_pred = False
cascade_class_ensemble = False
image_size = (256, 256)
images = np.random.rand(2, image_size[0], image_size[1], 3)
image_shape = np.array([[224, 100], [100, 224]])
......@@ -159,9 +172,23 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
max_level=max_level,
num_anchors_per_location=num_anchors_per_location)
detection_head = instance_heads.DetectionHead(
num_classes=num_classes)
num_classes=num_classes,
class_agnostic_bbox_pred=class_agnostic_bbox_pred)
roi_generator_obj = roi_generator.MultilevelROIGenerator()
roi_sampler_cascade = []
roi_sampler_obj = roi_sampler.ROISampler()
roi_sampler_cascade.append(roi_sampler_obj)
if cascade_iou_thresholds:
for iou in cascade_iou_thresholds:
roi_sampler_obj = roi_sampler.ROISampler(
mix_gt_boxes=False,
foreground_iou_threshold=iou,
background_iou_high_threshold=iou,
background_iou_low_threshold=0.0,
skip_subsampling=True)
roi_sampler_cascade.append(roi_sampler_obj)
roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator()
if include_mask:
......@@ -180,12 +207,14 @@ class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
rpn_head,
detection_head,
roi_generator_obj,
roi_sampler_obj,
roi_sampler_cascade,
roi_aligner_obj,
detection_generator_obj,
mask_head,
mask_sampler_obj,
mask_roi_aligner_obj)
mask_roi_aligner_obj,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble)
gt_boxes = np.array(
[[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
......
......@@ -161,6 +161,7 @@ class MaskRCNNTask(base_task.Task):
aux_losses: Optional[Any] = None):
"""Build Mask R-CNN losses."""
params = self.task_config
cascade_ious = params.model.roi_sampler.cascade_iou_thresholds
rpn_score_loss_fn = maskrcnn_losses.RpnScoreLoss(
tf.shape(outputs['box_outputs'])[1])
......@@ -175,15 +176,32 @@ class MaskRCNNTask(base_task.Task):
frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss()
frcnn_box_loss_fn = maskrcnn_losses.FastrcnnBoxLoss(
params.losses.frcnn_huber_loss_delta)
frcnn_cls_loss = tf.reduce_mean(
params.losses.frcnn_huber_loss_delta,
params.model.detection_head.class_agnostic_bbox_pred)
# Final cls/box losses are computed as an average of all detection heads.
frcnn_cls_loss = 0.0
frcnn_box_loss = 0.0
num_det_heads = 1 if cascade_ious is None else 1 + len(cascade_ious)
for cas_num in range(num_det_heads):
frcnn_cls_loss_i = tf.reduce_mean(
frcnn_cls_loss_fn(
outputs['class_outputs'], outputs['class_targets']))
frcnn_box_loss = tf.reduce_mean(
outputs['class_outputs_{}'
.format(cas_num) if cas_num else 'class_outputs'],
outputs['class_targets_{}'
.format(cas_num) if cas_num else 'class_targets']))
frcnn_box_loss_i = tf.reduce_mean(
frcnn_box_loss_fn(
outputs['box_outputs'],
outputs['class_targets'],
outputs['box_targets']))
outputs['box_outputs_{}'.format(cas_num
) if cas_num else 'box_outputs'],
outputs['class_targets_{}'
.format(cas_num) if cas_num else 'class_targets'],
outputs['box_targets_{}'.format(cas_num
) if cas_num else 'box_targets']))
frcnn_cls_loss += frcnn_cls_loss_i
frcnn_box_loss += frcnn_box_loss_i
frcnn_cls_loss /= num_det_heads
frcnn_box_loss /= num_det_heads
if params.model.include_mask:
mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
......
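Note: the repeated `'{}_{}'.format(...) if cas_num else ...` expressions here and in `MaskRCNNModel._run_frcnn_head` all encode one rule: head 0 keeps the original output/target keys and later heads get a numeric suffix. A small hypothetical helper (not part of this change) that states the rule:

def head_key(base, cas_num):
  """E.g. 'class_outputs', 'class_outputs_1', 'class_outputs_2'."""
  return '{}_{}'.format(base, cas_num) if cas_num else base

assert head_key('box_targets', 0) == 'box_targets'
assert head_key('box_targets', 2) == 'box_targets_2'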