Commit 6faf56a6 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Make apply_nms configurable.

PiperOrigin-RevId: 371735494
parent a90e36a4
...@@ -144,6 +144,7 @@ class ROIAligner(hyperparams.Config): ...@@ -144,6 +144,7 @@ class ROIAligner(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class DetectionGenerator(hyperparams.Config): class DetectionGenerator(hyperparams.Config):
apply_nms: bool = True
pre_nms_top_k: int = 5000 pre_nms_top_k: int = 5000
pre_nms_score_threshold: float = 0.05 pre_nms_score_threshold: float = 0.05
nms_iou_threshold: float = 0.5 nms_iou_threshold: float = 0.5
......
...@@ -106,6 +106,7 @@ class RetinaNetHead(hyperparams.Config): ...@@ -106,6 +106,7 @@ class RetinaNetHead(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class DetectionGenerator(hyperparams.Config): class DetectionGenerator(hyperparams.Config):
apply_nms: bool = True
pre_nms_top_k: int = 5000 pre_nms_top_k: int = 5000
pre_nms_score_threshold: float = 0.05 pre_nms_score_threshold: float = 0.05
nms_iou_threshold: float = 0.5 nms_iou_threshold: float = 0.5
......
...@@ -160,7 +160,7 @@ def build_maskrcnn( ...@@ -160,7 +160,7 @@ def build_maskrcnn(
sample_offset=roi_aligner_config.sample_offset) sample_offset=roi_aligner_config.sample_offset)
detection_generator_obj = detection_generator.DetectionGenerator( detection_generator_obj = detection_generator.DetectionGenerator(
apply_nms=True, apply_nms=generator_config.apply_nms,
pre_nms_top_k=generator_config.pre_nms_top_k, pre_nms_top_k=generator_config.pre_nms_top_k,
pre_nms_score_threshold=generator_config.pre_nms_score_threshold, pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold, nms_iou_threshold=generator_config.nms_iou_threshold,
...@@ -255,7 +255,7 @@ def build_retinanet( ...@@ -255,7 +255,7 @@ def build_retinanet(
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
detection_generator_obj = detection_generator.MultilevelDetectionGenerator( detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
apply_nms=True, apply_nms=generator_config.apply_nms,
pre_nms_top_k=generator_config.pre_nms_top_k, pre_nms_top_k=generator_config.pre_nms_top_k,
pre_nms_score_threshold=generator_config.pre_nms_score_threshold, pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
nms_iou_threshold=generator_config.nms_iou_threshold, nms_iou_threshold=generator_config.nms_iou_threshold,
......
...@@ -216,11 +216,21 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -216,11 +216,21 @@ class MaskRCNNModel(tf.keras.Model):
regression_weights, regression_weights,
bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred'])) bbox_per_class=(not self._config_dict['class_agnostic_bbox_pred']))
model_outputs.update({ model_outputs.update({
'detection_boxes': detections['detection_boxes'], 'cls_outputs': class_outputs,
'detection_scores': detections['detection_scores'], 'box_outputs': box_outputs,
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections'],
}) })
if self.detection_generator.get_config()['apply_nms']:
model_outputs.update({
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
'detection_classes': detections['detection_classes'],
'num_detections': detections['num_detections']
})
else:
model_outputs.update({
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores']
})
if not self._include_mask: if not self._include_mask:
return model_outputs return model_outputs
......
...@@ -148,13 +148,22 @@ class RetinaNetModel(tf.keras.Model): ...@@ -148,13 +148,22 @@ class RetinaNetModel(tf.keras.Model):
final_results = self.detection_generator( final_results = self.detection_generator(
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes) raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
outputs = { outputs = {
'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'],
'num_detections': final_results['num_detections'],
'cls_outputs': raw_scores, 'cls_outputs': raw_scores,
'box_outputs': raw_boxes, 'box_outputs': raw_boxes,
} }
if self.detection_generator.get_config()['apply_nms']:
outputs.update({
'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'],
'num_detections': final_results['num_detections']
})
else:
outputs.update({
'decoded_boxes': final_results['decoded_boxes'],
'decoded_box_scores': final_results['decoded_box_scores']
})
if raw_attributes: if raw_attributes:
outputs.update({ outputs.update({
'att_outputs': raw_attributes, 'att_outputs': raw_attributes,
......
...@@ -126,14 +126,23 @@ class DetectionModule(export_base.ExportModule): ...@@ -126,14 +126,23 @@ class DetectionModule(export_base.ExportModule):
anchor_boxes=anchor_boxes, anchor_boxes=anchor_boxes,
training=False) training=False)
final_outputs = { if self.params.task.model.detection_generator.apply_nms:
'detection_boxes': detections['detection_boxes'], final_outputs = {
'detection_scores': detections['detection_scores'], 'detection_boxes': detections['detection_boxes'],
'detection_classes': detections['detection_classes'], 'detection_scores': detections['detection_scores'],
'num_detections': detections['num_detections'], 'detection_classes': detections['detection_classes'],
'image_info': image_info 'num_detections': detections['num_detections']
} }
else:
final_outputs = {
'decoded_boxes': detections['decoded_boxes'],
'decoded_box_scores': detections['decoded_box_scores'],
'cls_outputs': detections['cls_outputs'],
'box_outputs': detections['box_outputs']
}
if 'detection_masks' in detections.keys(): if 'detection_masks' in detections.keys():
final_outputs['detection_masks'] = detections['detection_masks'] final_outputs['detection_masks'] = detections['detection_masks']
final_outputs.update({'image_info': image_info})
return final_outputs return final_outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment