Unverified Commit 4c764049 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

removed config param `evaluate_panoptic_quality`

parent 910ccf1a
......@@ -141,7 +141,6 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
# 'all': Initialize all modules
init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
default_factory=list)
evaluate_panoptic_quality: bool = True
panoptic_quality_evaluator: PanopticQualityEvaluator = PanopticQualityEvaluator() # pylint: disable=line-too-long
@exp_factory.register_config_factory('panoptic_maskrcnn_resnetfpn_coco')
......
......@@ -246,7 +246,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
rescale_predictions=rescale_predictions,
dtype=tf.float32)
if self.task_config.evaluate_panoptic_quality:
if self.task_config.model.generate_panoptic_masks:
assert (
self.task_config.validation_data.parser.include_panoptic_masks,
'`include_panoptic_masks` should be set to True when computing '
......@@ -376,7 +376,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
segmentation_labels,
outputs['segmentation_outputs'])
})
if self.task_config.evaluate_panoptic_quality:
if self.task_config.model.generate_panoptic_masks:
pq_metric_labels = {
'category_mask':
labels['groundtruths']['gt_panoptic_category_mask'],
......@@ -393,7 +393,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.coco_metric.reset_states()
self.segmentation_perclass_iou_metric.reset_states()
state = [self.coco_metric, self.segmentation_perclass_iou_metric]
if self.task_config.evaluate_panoptic_quality:
if self.task_config.model.generate_panoptic_masks:
state += [self.panoptic_quality_metric]
self.coco_metric.update_state(
......@@ -403,7 +403,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
step_outputs[self.segmentation_perclass_iou_metric.name][0],
step_outputs[self.segmentation_perclass_iou_metric.name][1])
if self.task_config.evaluate_panoptic_quality:
if self.task_config.model.generate_panoptic_masks:
self.panoptic_quality_metric.update_state(
step_outputs[self.panoptic_quality_metric.name][0],
step_outputs[self.panoptic_quality_metric.name][1])
......@@ -424,7 +424,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
# Computes mean IoU
result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()})
if self.task_config.evaluate_panoptic_quality:
if self.task_config.model.generate_panoptic_masks:
for k, value in self.panoptic_quality_metric.result().items():
result['panoptic_quality/' + k] = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment