Unverified Commit 085e46f8 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Panoptic Quality Evaluator Changes (#4)

* added flag to disable generation of panoptic masks

* added panoptic_quality_evaluator to task

* added `PanopticQualityEvaluator` config

* added default strategy to testcase

* import panoptic_maskrcnn project

* fixed shapes of panoptic masks
parent ca552843
...@@ -12,3 +12,4 @@ ...@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from official.vision.beta.projects import panoptic_maskrcnn
...@@ -93,6 +93,7 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN): ...@@ -93,6 +93,7 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN):
include_mask = True include_mask = True
shared_backbone: bool = True shared_backbone: bool = True
shared_decoder: bool = True shared_decoder: bool = True
generate_panoptic_masks: bool = True
panoptic_segmentation_generator: PanopticSegmentationGenerator = \ panoptic_segmentation_generator: PanopticSegmentationGenerator = \
PanopticSegmentationGenerator() PanopticSegmentationGenerator()
...@@ -109,6 +110,16 @@ class Losses(maskrcnn.Losses): ...@@ -109,6 +110,16 @@ class Losses(maskrcnn.Losses):
semantic_segmentation_weight: float = 1.0 semantic_segmentation_weight: float = 1.0
@dataclasses.dataclass
class PanopticQualityEvaluator(hyperparams.Config):
  """Panoptic Quality Evaluator config.

  Every field is forwarded by name to the
  `panoptic_quality_evaluator.PanopticQualityEvaluator` constructor when the
  task builds its evaluation metrics.

  Attributes:
    num_categories: Value passed as the evaluator's `num_categories`.
    ignored_label: Value passed as the evaluator's `ignored_label`.
    max_instances_per_category: Value passed as the evaluator's
      `max_instances_per_category`.
    offset: Value passed as the evaluator's `offset`.
    is_thing: One boolean flag per category id, passed as the evaluator's
      `is_thing`. Presumably True marks a "thing" (instance) category and
      False a "stuff" category -- confirm against the evaluator.
  """
  num_categories: int = 2
  ignored_label: int = 0
  max_instances_per_category: int = 100
  offset: int = 256 * 256 * 256
  # The experiment factory fills this with booleans, so annotate it as such.
  is_thing: List[bool] = dataclasses.field(
      default_factory=list)
@dataclasses.dataclass @dataclasses.dataclass
class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
"""Panoptic Mask R-CNN task config.""" """Panoptic Mask R-CNN task config."""
...@@ -130,7 +141,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -130,7 +141,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
# 'all': Initialize all modules # 'all': Initialize all modules
init_checkpoint_modules: Optional[List[str]] = dataclasses.field( init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
default_factory=list) default_factory=list)
evaluate_panoptic_quality: bool = True
panoptic_quality_evaluator: PanopticQualityEvaluator = PanopticQualityEvaluator() # pylint: disable=line-too-long
@exp_factory.register_config_factory('panoptic_maskrcnn_resnetfpn_coco') @exp_factory.register_config_factory('panoptic_maskrcnn_resnetfpn_coco')
def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig: def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
...@@ -149,8 +161,14 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig: ...@@ -149,8 +161,14 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
# and map all thing categories to id=1, the remaining 109 stuff categories # and map all thing categories to id=1, the remaining 109 stuff categories
# are shifted by an offset=90 given by num_thing classes - 1. This shifting # are shifted by an offset=90 given by num_thing classes - 1. This shifting
# will make all the stuff categories begin from id=2 and end at id=110 # will make all the stuff categories begin from id=2 and end at id=110
num_panoptic_categories = 201
num_thing_categories = 91
num_semantic_segmentation_classes = 111

# One flag per panoptic category id. Entry 0 is always False (id 0 is passed
# as `ignored_label` to the evaluator config below); ids 1 through
# `num_thing_categories` are flagged True and every remaining id False.
is_thing = [False] + [
    idx <= num_thing_categories
    for idx in range(1, num_panoptic_categories)
]
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'), runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
task=PanopticMaskRCNNTask( task=PanopticMaskRCNNTask(
...@@ -177,7 +195,11 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig: ...@@ -177,7 +195,11 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
global_batch_size=eval_batch_size, global_batch_size=eval_batch_size,
drop_remainder=False), drop_remainder=False),
annotation_file=os.path.join(_COCO_INPUT_PATH_BASE, annotation_file=os.path.join(_COCO_INPUT_PATH_BASE,
'instances_val2017.json')), 'instances_val2017.json'),
panoptic_quality_evaluator=PanopticQualityEvaluator(
num_categories=num_panoptic_categories,
ignored_label=0,
is_thing=is_thing)),
trainer=cfg.TrainerConfig( trainer=cfg.TrainerConfig(
train_steps=22500, train_steps=22500,
validation_steps=validation_steps, validation_steps=validation_steps,
......
...@@ -330,6 +330,9 @@ class Parser(maskrcnn_input.Parser): ...@@ -330,6 +330,9 @@ class Parser(maskrcnn_input.Parser):
data['groundtruth_panoptic_instance_mask'], data['groundtruth_panoptic_instance_mask'],
self._panoptic_ignore_label, image_info) self._panoptic_ignore_label, image_info)
panoptic_category_mask = panoptic_category_mask[:, :, 0]
panoptic_instance_mask = panoptic_instance_mask[:, :, 0]
labels['groundtruths'].update({ labels['groundtruths'].update({
'gt_panoptic_category_mask': panoptic_category_mask, 'gt_panoptic_category_mask': panoptic_category_mask,
'gt_panoptic_instance_mask': panoptic_instance_mask}) 'gt_panoptic_instance_mask': panoptic_instance_mask})
......
...@@ -93,17 +93,21 @@ def build_panoptic_maskrcnn( ...@@ -93,17 +93,21 @@ def build_panoptic_maskrcnn(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
if model_config.generate_panoptic_masks:
max_num_detections = model_config.detection_generator.max_num_detections max_num_detections = model_config.detection_generator.max_num_detections
mask_binarize_threshold = postprocessing_config.mask_binarize_threshold
panoptic_segmentation_generator_obj = \ panoptic_segmentation_generator_obj = \
panoptic_segmentation_generator.PanopticSegmentationGenerator( panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=postprocessing_config.output_size, output_size=postprocessing_config.output_size,
max_num_detections=max_num_detections, max_num_detections=max_num_detections,
stuff_classes_offset=postprocessing_config.stuff_classes_offset, stuff_classes_offset=postprocessing_config.stuff_classes_offset,
mask_binarize_threshold=postprocessing_config.mask_binarize_threshold, mask_binarize_threshold=mask_binarize_threshold,
score_threshold=postprocessing_config.score_threshold, score_threshold=postprocessing_config.score_threshold,
things_class_label=postprocessing_config.things_class_label, things_class_label=postprocessing_config.things_class_label,
void_class_label=postprocessing_config.void_class_label, void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id) void_instance_id=postprocessing_config.void_instance_id)
else:
panoptic_segmentation_generator_obj = None
# Combines maskrcnn, and segmentation models to build panoptic segmentation # Combines maskrcnn, and segmentation models to build panoptic segmentation
# model. # model.
......
...@@ -52,6 +52,7 @@ class PanopticSegmentationGeneratorTest( ...@@ -52,6 +52,7 @@ class PanopticSegmentationGeneratorTest(
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
strategy=[ strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy, strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu, strategy_combinations.one_device_strategy_gpu,
])) ]))
......
...@@ -25,7 +25,8 @@ from official.vision.beta.modeling import maskrcnn_model ...@@ -25,7 +25,8 @@ from official.vision.beta.modeling import maskrcnn_model
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
"""The Panoptic Segmentation model.""" """The Panoptic Segmentation model."""
def __init__(self, def __init__(
self,
backbone: tf.keras.Model, backbone: tf.keras.Model,
decoder: tf.keras.Model, decoder: tf.keras.Model,
rpn_head: tf.keras.layers.Layer, rpn_head: tf.keras.layers.Layer,
...@@ -36,7 +37,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -36,7 +37,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
List[tf.keras.layers.Layer]], List[tf.keras.layers.Layer]],
roi_aligner: tf.keras.layers.Layer, roi_aligner: tf.keras.layers.Layer,
detection_generator: tf.keras.layers.Layer, detection_generator: tf.keras.layers.Layer,
panoptic_segmentation_generator: tf.keras.layers.Layer, panoptic_segmentation_generator: Optional[tf.keras.layers.Layer] = None,
mask_head: Optional[tf.keras.layers.Layer] = None, mask_head: Optional[tf.keras.layers.Layer] = None,
mask_sampler: Optional[tf.keras.layers.Layer] = None, mask_sampler: Optional[tf.keras.layers.Layer] = None,
mask_roi_aligner: Optional[tf.keras.layers.Layer] = None, mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
...@@ -120,10 +121,13 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -120,10 +121,13 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
self._config_dict.update({ self._config_dict.update({
'segmentation_backbone': segmentation_backbone, 'segmentation_backbone': segmentation_backbone,
'segmentation_decoder': segmentation_decoder, 'segmentation_decoder': segmentation_decoder,
'segmentation_head': segmentation_head, 'segmentation_head': segmentation_head
'panoptic_segmentation_generator': panoptic_segmentation_generator
}) })
if panoptic_segmentation_generator is not None:
self._config_dict.update(
{'panoptic_segmentation_generator': panoptic_segmentation_generator})
if not self._include_mask: if not self._include_mask:
raise ValueError( raise ValueError(
'`mask_head` needs to be provided for Panoptic Mask R-CNN.') '`mask_head` needs to be provided for Panoptic Mask R-CNN.')
...@@ -172,11 +176,9 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -172,11 +176,9 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
'segmentation_outputs': segmentation_outputs, 'segmentation_outputs': segmentation_outputs,
}) })
if not training: if not training and self.panoptic_segmentation_generator is not None:
panoptic_outputs = self.panoptic_segmentation_generator(model_outputs) panoptic_outputs = self.panoptic_segmentation_generator(model_outputs)
model_outputs.update({ model_outputs.update({'panoptic_outputs': panoptic_outputs})
'panoptic_outputs': panoptic_outputs
})
return model_outputs return model_outputs
......
...@@ -174,9 +174,10 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -174,9 +174,10 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
shared_backbone=[True, False], shared_backbone=[True, False],
shared_decoder=[True, False], shared_decoder=[True, False],
training=[True, False], training=[True, False],
)) generate_panoptic_masks=[True, False]))
def test_forward(self, strategy, training, def test_forward(self, strategy, training,
shared_backbone, shared_decoder): shared_backbone, shared_decoder,
generate_panoptic_masks):
num_classes = 3 num_classes = 3
min_level = 3 min_level = 3
max_level = 4 max_level = 4
...@@ -228,11 +229,16 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -228,11 +229,16 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
roi_sampler_cascade.append(roi_sampler_obj) roi_sampler_cascade.append(roi_sampler_obj)
roi_aligner_obj = roi_aligner.MultilevelROIAligner() roi_aligner_obj = roi_aligner.MultilevelROIAligner()
detection_generator_obj = detection_generator.DetectionGenerator() detection_generator_obj = detection_generator.DetectionGenerator()
if generate_panoptic_masks:
panoptic_segmentation_generator_obj = \ panoptic_segmentation_generator_obj = \
panoptic_segmentation_generator.PanopticSegmentationGenerator( panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=list(image_size), output_size=list(image_size),
max_num_detections=100, max_num_detections=100,
stuff_classes_offset=90) stuff_classes_offset=90)
else:
panoptic_segmentation_generator_obj = None
mask_head = instance_heads.MaskHead( mask_head = instance_heads.MaskHead(
num_classes=num_classes, upsample_factor=2) num_classes=num_classes, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler( mask_sampler_obj = mask_sampler.MaskSampler(
...@@ -311,19 +317,23 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -311,19 +317,23 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertIn('num_detections', results) self.assertIn('num_detections', results)
self.assertIn('detection_masks', results) self.assertIn('detection_masks', results)
self.assertIn('segmentation_outputs', results) self.assertIn('segmentation_outputs', results)
self.assertIn('panoptic_outputs', results)
self.assertIn('category_mask', results['panoptic_outputs'])
self.assertIn('instance_mask', results['panoptic_outputs'])
self.assertAllEqual( self.assertAllEqual(
[2, image_size[0] // (2**level), image_size[1] // (2**level), 2], [2, image_size[0] // (2**level), image_size[1] // (2**level), 2],
results['segmentation_outputs'].numpy().shape) results['segmentation_outputs'].numpy().shape)
if generate_panoptic_masks:
self.assertIn('panoptic_outputs', results)
self.assertIn('category_mask', results['panoptic_outputs'])
self.assertIn('instance_mask', results['panoptic_outputs'])
self.assertAllEqual( self.assertAllEqual(
[2, image_size[0], image_size[1]], [2, image_size[0], image_size[1]],
results['panoptic_outputs']['category_mask'].numpy().shape) results['panoptic_outputs']['category_mask'].numpy().shape)
self.assertAllEqual( self.assertAllEqual(
[2, image_size[0], image_size[1]], [2, image_size[0], image_size[1]],
results['panoptic_outputs']['instance_mask'].numpy().shape) results['panoptic_outputs']['instance_mask'].numpy().shape)
else:
self.assertNotIn('panoptic_outputs', results)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
......
...@@ -23,6 +23,7 @@ from official.core import task_factory ...@@ -23,6 +23,7 @@ from official.core import task_factory
from official.vision.beta.dataloaders import input_reader_factory from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.evaluation import coco_evaluator from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.evaluation import segmentation_metrics from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.evaluation import panoptic_quality_evaluator
from official.vision.beta.losses import segmentation_losses from official.vision.beta.losses import segmentation_losses
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as exp_cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as exp_cfg
from official.vision.beta.projects.panoptic_maskrcnn.dataloaders import panoptic_maskrcnn_input from official.vision.beta.projects.panoptic_maskrcnn.dataloaders import panoptic_maskrcnn_input
...@@ -208,6 +209,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -208,6 +209,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
tf.keras.metrics.Metric]: tf.keras.metrics.Metric]:
"""Build detection metrics.""" """Build detection metrics."""
metrics = [] metrics = []
num_segmentation_classes = \
self.task_config.model.segmentation_model.num_classes
if training: if training:
metric_names = [ metric_names = [
'total_loss', 'total_loss',
...@@ -226,7 +229,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -226,7 +229,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
if self.task_config.segmentation_evaluation.report_train_mean_iou: if self.task_config.segmentation_evaluation.report_train_mean_iou:
self.segmentation_train_mean_iou = segmentation_metrics.MeanIoU( self.segmentation_train_mean_iou = segmentation_metrics.MeanIoU(
name='train_mean_iou', name='train_mean_iou',
num_classes=self.task_config.model.num_classes, num_classes=num_segmentation_classes,
rescale_predictions=False, rescale_predictions=False,
dtype=tf.float32) dtype=tf.float32)
...@@ -240,9 +243,24 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -240,9 +243,24 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
.segmentation_resize_eval_groundtruth) .segmentation_resize_eval_groundtruth)
self.segmentation_perclass_iou_metric = segmentation_metrics.PerClassIoU( self.segmentation_perclass_iou_metric = segmentation_metrics.PerClassIoU(
name='per_class_iou', name='per_class_iou',
num_classes=self.task_config.model.num_classes, num_classes=num_segmentation_classes,
rescale_predictions=rescale_predictions, rescale_predictions=rescale_predictions,
dtype=tf.float32) dtype=tf.float32)
if self.task_config.evaluate_panoptic_quality:
  # Validate the configuration with explicit raises. The previous
  # `assert (condition, message)` asserted a non-empty tuple, which is
  # always truthy, so the check could never fail (and asserts are stripped
  # under `python -O` anyway).
  if not self.task_config.validation_data.parser.include_eval_masks:
    raise ValueError(
        '`include_eval_masks` should be set to True when computing '
        'panoptic quality')
  # The validation step reads `outputs['panoptic_outputs']`, which the model
  # only emits when a panoptic segmentation generator was built. Fail here
  # with a clear message instead of a KeyError in the middle of evaluation.
  if not self.task_config.model.generate_panoptic_masks:
    raise ValueError(
        '`generate_panoptic_masks` should be set to True when computing '
        'panoptic quality')

  pq_config = self.task_config.panoptic_quality_evaluator
  self.panoptic_quality_metric = (
      panoptic_quality_evaluator.PanopticQualityEvaluator(
          num_categories=pq_config.num_categories,
          ignored_label=pq_config.ignored_label,
          max_instances_per_category=pq_config.max_instances_per_category,
          offset=pq_config.offset,
          is_thing=pq_config.is_thing))
return metrics return metrics
def train_step(self, def train_step(self,
...@@ -360,6 +378,16 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -360,6 +378,16 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
segmentation_labels, segmentation_labels,
outputs['segmentation_outputs']) outputs['segmentation_outputs'])
}) })
if self.task_config.evaluate_panoptic_quality:
pq_metric_labels = {
'category_mask':
labels['groundtruths']['gt_panoptic_category_mask'],
'instance_mask':
labels['groundtruths']['gt_panoptic_instance_mask']
}
logs.update({
self.panoptic_quality_metric.name:
(pq_metric_labels, outputs['panoptic_outputs'])})
return logs return logs
def aggregate_logs(self, state=None, step_outputs=None): def aggregate_logs(self, state=None, step_outputs=None):
...@@ -367,6 +395,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -367,6 +395,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.coco_metric.reset_states() self.coco_metric.reset_states()
self.segmentation_perclass_iou_metric.reset_states() self.segmentation_perclass_iou_metric.reset_states()
state = [self.coco_metric, self.segmentation_perclass_iou_metric] state = [self.coco_metric, self.segmentation_perclass_iou_metric]
if self.task_config.evaluate_panoptic_quality:
state += [self.panoptic_quality_metric]
self.coco_metric.update_state( self.coco_metric.update_state(
step_outputs[self.coco_metric.name][0], step_outputs[self.coco_metric.name][0],
...@@ -374,6 +404,12 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -374,6 +404,12 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.segmentation_perclass_iou_metric.update_state( self.segmentation_perclass_iou_metric.update_state(
step_outputs[self.segmentation_perclass_iou_metric.name][0], step_outputs[self.segmentation_perclass_iou_metric.name][0],
step_outputs[self.segmentation_perclass_iou_metric.name][1]) step_outputs[self.segmentation_perclass_iou_metric.name][1])
if self.task_config.evaluate_panoptic_quality:
self.panoptic_quality_metric.update_state(
step_outputs[self.panoptic_quality_metric.name][0],
step_outputs[self.panoptic_quality_metric.name][1])
return state return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None): def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
...@@ -389,4 +425,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -389,4 +425,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
result.update({'segmentation_iou/class_{}'.format(i): value}) result.update({'segmentation_iou/class_{}'.format(i): value})
# Computes mean IoU # Computes mean IoU
result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()}) result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()})
if self.task_config.evaluate_panoptic_quality:
for k, value in self.panoptic_quality_metric.result().items():
result['panoptic_quality/' + k] = value
return result return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment