Commit 1fed4144 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add model export config to RetinaNet, which optionally cast model outputs to...

Add model export config to RetinaNet, which optionally cast model outputs to floats, and normalize output box coordinates.

PiperOrigin-RevId: 381334851
parent b53e5dc0
......@@ -130,6 +130,13 @@ class RetinaNet(hyperparams.Config):
norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class ExportConfig(hyperparams.Config):
output_normalized_coordinates: bool = False
cast_num_detections_to_float: bool = False
cast_detection_classes_to_float: bool = False
@dataclasses.dataclass
class RetinaNetTask(cfg.TaskConfig):
model: RetinaNet = RetinaNet()
......@@ -140,6 +147,7 @@ class RetinaNetTask(cfg.TaskConfig):
init_checkpoint_modules: str = 'all' # all or backbone
annotation_file: Optional[str] = None
per_category_metrics: bool = False
export_config: ExportConfig = ExportConfig()
@exp_factory.register_config_factory('retinanet')
......
......@@ -17,6 +17,7 @@
import tensorflow as tf
from cloud_tpu.models.detection.utils import box_utils
from official.vision.beta import configs
from official.vision.beta.modeling import factory
from official.vision.beta.ops import anchor
......@@ -130,6 +131,28 @@ class DetectionModule(export_base.ExportModule):
training=False)
if self.params.task.model.detection_generator.apply_nms:
# For RetinaNet model, apply export_config.
# TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
export_config = self.params.task.export_config
# Normalize detection box coordinates to [0, 1].
if export_config.output_normalized_coordinates:
detection_boxes = (
detections['detection_boxes'] /
tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
detections['detection_boxes'] = box_utils.normalize_boxes(
detection_boxes, image_info[:, 0:1, :])
# Cast num_detections and detection_classes to float. This allows the
# model inference to work on chain (go/chain) as chain requires floating
# point outputs.
if export_config.cast_num_detections_to_float:
detections['num_detections'] = tf.cast(
detections['num_detections'], dtype=tf.float32)
if export_config.cast_detection_classes_to_float:
detections['detection_classes'] = tf.cast(
detections['detection_classes'], dtype=tf.float32)
final_outputs = {
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment