Commit 6faf56a6 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Make apply_nms configurable.

PiperOrigin-RevId: 371735494
parent a90e36a4
......@@ -144,6 +144,7 @@ class ROIAligner(hyperparams.Config):
@dataclasses.dataclass
class DetectionGenerator(hyperparams.Config):
apply_nms: bool = True
pre_nms_top_k: int = 5000
pre_nms_score_threshold: float = 0.05
nms_iou_threshold: float = 0.5
......
......@@ -106,6 +106,7 @@ class RetinaNetHead(hyperparams.Config):
@dataclasses.dataclass
class DetectionGenerator(hyperparams.Config):
apply_nms: bool = True
pre_nms_top_k: int = 5000
pre_nms_score_threshold: float = 0.05
nms_iou_threshold: float = 0.5
......
......@@ -160,7 +160,7 @@ def build_maskrcnn(
sample_offset=roi_aligner_config.sample_offset)
detection_generator_obj = detection_generator.DetectionGenerator(
apply_nms=True,
apply_nms=generator_config.apply_nms,
pre_nms_top_k=generator_config.pre_nms_top_k,
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold,
......@@ -255,7 +255,7 @@ def build_retinanet(
kernel_regularizer=l2_regularizer)
detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
apply_nms=True,
apply_nms=generator_config.apply_nms,
pre_nms_top_k=generator_config.pre_nms_top_k,
pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold,
......
......@@ -215,11 +215,21 @@ class MaskRCNNModel(tf.keras.Model):
image_shape,
regression_weights,
bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred']))
model_outputs.update({
'cls_outputs': class_outputs,
'box_outputs': box_outputs,
})
if self.detection_generator.get_config()['apply_nms']:
model_outputs.update({
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections'],
'num_detections': detections['num_detections']
})
else:
model_outputs.update({
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores']
})
if not self._include_mask:
......
......@@ -148,13 +148,22 @@ class RetinaNetModel(tf.keras.Model):
final_results = self.detection_generator(
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
outputs = {
'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'],
'num_detections': final_results['num_detections'],
'cls_outputs': raw_scores,
'box_outputs': raw_boxes,
}
if self.detection_generator.get_config()['apply_nms']:
outputs.update({
'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'],
'num_detections': final_results['num_detections']
})
else:
outputs.update({
'decoded_boxes': final_results['decoded_boxes'],
'decoded_box_scores': final_results['decoded_box_scores']
})
if raw_attributes:
outputs.update({
'att_outputs': raw_attributes,
......
......@@ -126,14 +126,23 @@ class DetectionModule(export_base.ExportModule):
anchor_boxes=anchor_boxes,
training=False)
if self.params.task.model.detection_generator.apply_nms:
final_outputs = {
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections'],
'image_info': image_info
'num_detections': detections['num_detections']
}
else:
final_outputs = {
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores'],
'cls_outputs': detections['cls_outputs'],
'box_outputs': detections['box_outputs']
}
if 'detection_masks' in detections.keys():
final_outputs['detection_masks'] = detections['detection_masks']
final_outputs.update({'image_info': image_info})
return final_outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment