Commit 9bcbe962 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 481231095
parent f23d7bc4
...@@ -244,7 +244,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -244,7 +244,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
dtype=tf.float32) dtype=tf.float32)
else: else:
self._build_coco_metrics() if self.task_config.use_coco_metrics:
self._build_coco_metrics()
rescale_predictions = (not self.task_config.validation_data.parser rescale_predictions = (not self.task_config.validation_data.parser
.segmentation_resize_eval_groundtruth) .segmentation_resize_eval_groundtruth)
...@@ -366,24 +367,25 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -366,24 +367,25 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
training=False) training=False)
logs = {self.loss: 0} logs = {self.loss: 0}
coco_model_outputs = { if self._task_config.use_coco_metrics:
'detection_masks': outputs['detection_masks'], coco_model_outputs = {
'detection_boxes': outputs['detection_boxes'], 'detection_masks': outputs['detection_masks'],
'detection_scores': outputs['detection_scores'], 'detection_boxes': outputs['detection_boxes'],
'detection_classes': outputs['detection_classes'], 'detection_scores': outputs['detection_scores'],
'num_detections': outputs['num_detections'], 'detection_classes': outputs['detection_classes'],
'source_id': labels['groundtruths']['source_id'], 'num_detections': outputs['num_detections'],
'image_info': labels['image_info'] 'source_id': labels['groundtruths']['source_id'],
} 'image_info': labels['image_info']
}
logs.update(
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
segmentation_labels = { segmentation_labels = {
'masks': labels['groundtruths']['gt_segmentation_mask'], 'masks': labels['groundtruths']['gt_segmentation_mask'],
'valid_masks': labels['groundtruths']['gt_segmentation_valid_mask'], 'valid_masks': labels['groundtruths']['gt_segmentation_valid_mask'],
'image_info': labels['image_info'] 'image_info': labels['image_info']
} }
logs.update(
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
self.segmentation_perclass_iou_metric.update_state( self.segmentation_perclass_iou_metric.update_state(
segmentation_labels, outputs['segmentation_outputs']) segmentation_labels, outputs['segmentation_outputs'])
...@@ -400,15 +402,18 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -400,15 +402,18 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
def aggregate_logs(self, state=None, step_outputs=None): def aggregate_logs(self, state=None, step_outputs=None):
if state is None: if state is None:
self.coco_metric.reset_states()
self.segmentation_perclass_iou_metric.reset_states() self.segmentation_perclass_iou_metric.reset_states()
state = [self.coco_metric, self.segmentation_perclass_iou_metric] state = [self.segmentation_perclass_iou_metric]
if self.task_config.use_coco_metrics:
self.coco_metric.reset_states()
state.append(self.coco_metric)
if self.task_config.model.generate_panoptic_masks: if self.task_config.model.generate_panoptic_masks:
state += [self.panoptic_quality_metric] self.panoptic_quality_metric.reset_states()
state.append(self.panoptic_quality_metric)
self.coco_metric.update_state( if self.task_config.use_coco_metrics:
step_outputs[self.coco_metric.name][0], self.coco_metric.update_state(step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1]) step_outputs[self.coco_metric.name][1])
if self.task_config.model.generate_panoptic_masks: if self.task_config.model.generate_panoptic_masks:
self.panoptic_quality_metric.update_state( self.panoptic_quality_metric.update_state(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment