Commit e653807f authored by A. Unique TensorFlower

Merge pull request #10203 from srihari-humbarwadi:panoptic-segmentation

PiperOrigin-RevId: 399527495
parents cfe32967 6112b4ce
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -20,6 +20,7 @@ from typing import List, Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import maskrcnn
@@ -72,6 +73,17 @@ class DataConfig(maskrcnn.DataConfig):
parser: Parser = Parser()
@dataclasses.dataclass
class PanopticSegmentationGenerator(hyperparams.Config):
output_size: List[int] = dataclasses.field(
default_factory=list)
mask_binarize_threshold: float = 0.5
score_threshold: float = 0.05
things_class_label: int = 1
void_class_label: int = 0
void_instance_id: int = 0
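# Editorial note (not part of this change): `output_size` defaults to an empty
# list and is expected to be set to the final image size when panoptic mask
# generation is enabled; the COCO experiment below sets it to [1024, 1024].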
@dataclasses.dataclass
class PanopticMaskRCNN(maskrcnn.MaskRCNN):
  """Panoptic Mask R-CNN model config."""
@@ -80,6 +92,9 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN):
include_mask = True
shared_backbone: bool = True
shared_decoder: bool = True
stuff_classes_offset: int = 0
generate_panoptic_masks: bool = True
panoptic_segmentation_generator: PanopticSegmentationGenerator = PanopticSegmentationGenerator() # pylint:disable=line-too-long
@dataclasses.dataclass
@@ -94,6 +109,17 @@ class Losses(maskrcnn.Losses):
semantic_segmentation_weight: float = 1.0
@dataclasses.dataclass
class PanopticQualityEvaluator(hyperparams.Config):
"""Panoptic Quality Evaluator config."""
num_categories: int = 2
ignored_label: int = 0
max_instances_per_category: int = 100
offset: int = 256 * 256 * 256
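# Editorial note (assumption, not from this change): 256 * 256 * 256 matches
# the id space of the COCO panoptic format, where segment ids are packed into
# RGB values; the evaluator presumably uses `offset` to form unique ids for
# (ground truth, prediction) label pairs, so it must exceed the largest
# possible panoptic label.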
is_thing: List[float] = dataclasses.field(
default_factory=list)
@dataclasses.dataclass
class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
  """Panoptic Mask R-CNN task config."""
@@ -115,6 +141,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
# 'all': Initialize all modules
init_checkpoint_modules: Optional[List[str]] = dataclasses.field(
    default_factory=list)
panoptic_quality_evaluator: PanopticQualityEvaluator = PanopticQualityEvaluator() # pylint: disable=line-too-long
@exp_factory.register_config_factory('panoptic_maskrcnn_resnetfpn_coco')
@@ -125,6 +152,22 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
steps_per_epoch = _COCO_TRAIN_EXAMPLES // train_batch_size
validation_steps = _COCO_VAL_EXAMPLES // eval_batch_size
# The COCO panoptic dataset has category ids in [0, 200] inclusive:
#   - id 0 is unused and represents the background class,
#   - ids 1-91 are the 91 thing categories,
#   - ids 92-200 are the 109 stuff categories.
# For the segmentation task, id=0 is still used for the background and all
# thing categories are mapped to id=1. The remaining 109 stuff categories are
# shifted by an offset of 90 (num_thing_categories - 1), so the stuff
# categories start at id=2 and end at id=110.
num_panoptic_categories = 201
num_thing_categories = 91
num_semantic_segmentation_classes = 111
is_thing = [False]
for idx in range(1, num_panoptic_categories):
is_thing.append(True if idx <= num_thing_categories else False)
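# Editorial illustration of the resulting label mapping (derived from the
# comment above):
#   panoptic id 0       -> semantic id 0 (background / void)
#   panoptic ids 1-91   -> semantic id 1 (single merged "things" class)
#   panoptic ids 92-200 -> semantic ids 2-110 (panoptic id - 90)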
config = cfg.ExperimentConfig(
    runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
    task=PanopticMaskRCNNTask(
@@ -132,8 +175,11 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
        init_checkpoint_modules=['backbone'],
        model=PanopticMaskRCNN(
            num_classes=91, input_size=[1024, 1024, 3],
panoptic_segmentation_generator=PanopticSegmentationGenerator(
output_size=[1024, 1024]),
stuff_classes_offset=90,
            segmentation_model=SEGMENTATION_MODEL(
                num_classes=num_semantic_segmentation_classes,
                head=SEGMENTATION_HEAD(level=3))),
        losses=Losses(l2_weight_decay=0.00004),
        train_data=DataConfig(
@@ -148,7 +194,11 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
            global_batch_size=eval_batch_size,
            drop_remainder=False),
            annotation_file=os.path.join(_COCO_INPUT_PATH_BASE,
                                         'instances_val2017.json'),
panoptic_quality_evaluator=PanopticQualityEvaluator(
num_categories=num_panoptic_categories,
ignored_label=0,
is_thing=is_thing)),
    trainer=cfg.TrainerConfig(
        train_steps=22500,
        validation_steps=validation_steps,
...
@@ -330,6 +330,9 @@ class Parser(maskrcnn_input.Parser):
        data['groundtruth_panoptic_instance_mask'],
        self._panoptic_ignore_label, image_info)
panoptic_category_mask = panoptic_category_mask[:, :, 0]
panoptic_instance_mask = panoptic_instance_mask[:, :, 0]
    labels['groundtruths'].update({
        'gt_panoptic_category_mask': panoptic_category_mask,
        'gt_panoptic_instance_mask': panoptic_instance_mask})
...
@@ -22,6 +22,7 @@ from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
def build_panoptic_maskrcnn(
@@ -73,6 +74,7 @@ def build_panoptic_maskrcnn(
  segmentation_head_config = segmentation_config.head
  detection_head_config = model_config.detection_head
postprocessing_config = model_config.panoptic_segmentation_generator
  segmentation_head = segmentation_heads.SegmentationHead(
      num_classes=segmentation_config.num_classes,
@@ -90,6 +92,21 @@ def build_panoptic_maskrcnn(
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
if model_config.generate_panoptic_masks:
max_num_detections = model_config.detection_generator.max_num_detections
mask_binarize_threshold = postprocessing_config.mask_binarize_threshold
panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=postprocessing_config.output_size,
max_num_detections=max_num_detections,
stuff_classes_offset=model_config.stuff_classes_offset,
mask_binarize_threshold=mask_binarize_threshold,
score_threshold=postprocessing_config.score_threshold,
things_class_label=postprocessing_config.things_class_label,
void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id)
else:
panoptic_segmentation_generator_obj = None
  # Combines the Mask R-CNN and segmentation models to build the panoptic
  # segmentation model.
  model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
@@ -101,6 +118,7 @@ def build_panoptic_maskrcnn(
      roi_sampler=maskrcnn_model.roi_sampler,
      roi_aligner=maskrcnn_model.roi_aligner,
      detection_generator=maskrcnn_model.detection_generator,
panoptic_segmentation_generator=panoptic_segmentation_generator_obj,
      mask_head=maskrcnn_model.mask_head,
      mask_sampler=maskrcnn_model.mask_sampler,
      mask_roi_aligner=maskrcnn_model.mask_roi_aligner,
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definition for postprocessing layer to genrate panoptic segmentations."""
from typing import List
import tensorflow as tf
class PanopticSegmentationGenerator(tf.keras.layers.Layer):
"""Panoptic segmentation generator layer."""
def __init__(
self,
output_size: List[int],
max_num_detections: int,
stuff_classes_offset: int,
mask_binarize_threshold: float = 0.5,
score_threshold: float = 0.05,
things_class_label: int = 1,
void_class_label: int = 0,
void_instance_id: int = -1,
**kwargs):
"""Generates panoptic segmentation masks.
Args:
output_size: A `List` of integers that represent the height and width of
the output mask.
max_num_detections: `int` for maximum number of detections.
stuff_classes_offset: An `int` that is added to the output of the
  semantic segmentation mask to make sure that the stuff class ids do not
  overlap with the thing class ids of the Mask R-CNN outputs.
mask_binarize_threshold: A `float` representing the threshold used to
  binarize the predicted instance masks.
score_threshold: A `float` representing the threshold for deciding
when to remove objects based on score.
things_class_label: An `int` that represents a single merged category of
all thing classes in the semantic segmentation output.
void_class_label: An `int` that is used to represent empty or unlabelled
  regions of the mask.
void_instance_id: An `int` that is used to denote regions that are not
  assigned to any thing class. That is, the void_instance_id is assigned
  to both stuff regions and empty regions.
**kwargs: Additional keyword arguments.
"""
self._output_size = output_size
self._max_num_detections = max_num_detections
self._stuff_classes_offset = stuff_classes_offset
self._mask_binarize_threshold = mask_binarize_threshold
self._score_threshold = score_threshold
self._things_class_label = things_class_label
self._void_class_label = void_class_label
self._void_instance_id = void_instance_id
self._config_dict = {
'output_size': output_size,
'max_num_detections': max_num_detections,
'stuff_classes_offset': stuff_classes_offset,
'mask_binarize_threshold': mask_binarize_threshold,
'score_threshold': score_threshold,
'things_class_label': things_class_label,
'void_class_label': void_class_label,
'void_instance_id': void_instance_id
}
super(PanopticSegmentationGenerator, self).__init__(**kwargs)
def _paste_mask(self, box, mask):
pasted_mask = tf.ones(
self._output_size + [1], dtype=mask.dtype) * self._void_class_label
ymin = box[0]
xmin = box[1]
ymax = tf.clip_by_value(box[2] + 1, 0, self._output_size[0])
xmax = tf.clip_by_value(box[3] + 1, 0, self._output_size[1])
box_height = ymax - ymin
box_width = xmax - xmin
# resize mask to match the shape of the instance bounding box
resized_mask = tf.image.resize(
mask,
size=(box_height, box_width),
method='nearest')
# paste resized mask on a blank mask that matches image shape
pasted_mask = tf.raw_ops.TensorStridedSliceUpdate(
input=pasted_mask,
begin=[ymin, xmin],
end=[ymax, xmax],
strides=[1, 1],
value=resized_mask)
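    # Editorial illustration (hypothetical values): with output_size=[640, 640],
    # a 28x28x1 mask and box [100, 200, 300, 400] is resized with nearest
    # neighbor to 201x201x1 and written into rows 100:301 and columns 200:401
    # of the 640x640x1 canvas; every other pixel keeps void_class_label.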
return pasted_mask
def _generate_panoptic_masks(self, boxes, scores, classes, detections_masks,
segmentation_mask):
"""Generates panoptic masks for a single image.
This function implements the following steps to merge instance and semantic
segmentation masks described in https://arxiv.org/pdf/1901.02446.pdf
Steps:
1. resolving overlaps between different instances based on their
confidence scores
2. resolving overlaps between instance and semantic segmentation
outputs in favor of instances
3. removing any stuff regions labeled 'other' or smaller than a given area
threshold.
Args:
boxes: A `tf.Tensor` of shape [num_rois, 4], representing the bounding
boxes for detected objects.
scores: A `tf.Tensor` of shape [num_rois], representing the
confidence scores for each object.
classes: A `tf.Tensor` of shape [num_rois], representing the class
for each object.
detections_masks: A `tf.Tensor` of shape
[num_rois, mask_height, mask_width, 1], representing the cropped mask
for each object.
segmentation_mask: A `tf.Tensor` of shape [height, width], representing
the semantic segmentation output.
Returns:
Dict with the following keys:
- category_mask: A `tf.Tensor` for category masks.
- instance_mask: A `tf.Tensor` for instance masks.
"""
# Offset stuff class predictions
segmentation_mask = tf.where(
tf.logical_or(
tf.equal(segmentation_mask, self._things_class_label),
tf.equal(segmentation_mask, self._void_class_label)),
segmentation_mask,
segmentation_mask + self._stuff_classes_offset
)
# sort instances by their scores
sorted_indices = tf.argsort(scores, direction='DESCENDING')
mask_shape = self._output_size + [1]
category_mask = tf.ones(mask_shape,
dtype=tf.float32) * self._void_class_label
instance_mask = tf.ones(
mask_shape, dtype=tf.float32) * self._void_instance_id
# filter instances with low confidence
sorted_scores = tf.sort(scores, direction='DESCENDING')
valid_indices = tf.where(sorted_scores > self._score_threshold)
# if no instance has sufficient confidence score, skip merging
# instance segmentation masks
if tf.shape(valid_indices)[0] > 0:
loop_end_idx = valid_indices[-1, 0] + 1
loop_end_idx = tf.minimum(
tf.cast(loop_end_idx, dtype=tf.int32),
self._max_num_detections)
# add things segmentation to panoptic masks
for i in range(loop_end_idx):
# we process instances in descending order, which will make sure
# the overlaps are resolved based on confidence score
instance_idx = sorted_indices[i]
pasted_mask = self._paste_mask(
box=boxes[instance_idx],
mask=detections_masks[instance_idx])
class_id = tf.cast(classes[instance_idx], dtype=tf.float32)
# convert sigmoid scores to binary values
binary_mask = tf.greater(
pasted_mask, self._mask_binarize_threshold)
# filter empty instance masks
if not tf.reduce_sum(tf.cast(binary_mask, tf.float32)) > 0:
continue
# fill empty regions in category_mask represented by
# void_class_label with class_id of the instance.
category_mask = tf.where(
tf.logical_and(
binary_mask, tf.equal(category_mask, self._void_class_label)),
tf.ones_like(category_mask) * class_id, category_mask)
# fill empty regions in the instance_mask represented by
# void_instance_id with the id of the instance, starting from 1
instance_mask = tf.where(
tf.logical_and(
binary_mask,
tf.equal(instance_mask, self._void_instance_id)),
tf.ones_like(instance_mask) *
tf.cast(instance_idx + 1, tf.float32), instance_mask)
# add stuff segmentation labels to empty regions of category_mask.
# we ignore the pixels labelled as "things", since we get them from
# the instance masks.
# TODO(srihari, arashwan): Support filtering stuff classes based on area.
category_mask = tf.where(
tf.logical_and(
tf.equal(
category_mask, self._void_class_label),
tf.logical_and(
tf.not_equal(segmentation_mask, self._things_class_label),
tf.not_equal(segmentation_mask, self._void_class_label))),
segmentation_mask, category_mask)
results = {
'category_mask': category_mask[:, :, 0],
'instance_mask': instance_mask[:, :, 0]
}
return results
def call(self, inputs):
detections = inputs
batched_scores = detections['detection_scores']
batched_classes = detections['detection_classes']
batched_boxes = tf.cast(detections['detection_boxes'], dtype=tf.int32)
batched_detections_masks = tf.expand_dims(
detections['detection_masks'], axis=-1)
batched_segmentation_masks = tf.image.resize(
detections['segmentation_outputs'],
size=self._output_size,
method='bilinear')
batched_segmentation_masks = tf.expand_dims(tf.cast(
tf.argmax(batched_segmentation_masks, axis=-1),
dtype=tf.float32), axis=-1)
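    # `segmentation_outputs` carries per-pixel class scores; the resize and
    # argmax above turn them into a dense semantic label map. The instance and
    # semantic outputs are then merged image by image via tf.map_fn below.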
panoptic_masks = tf.map_fn(
fn=lambda x: self._generate_panoptic_masks( # pylint:disable=g-long-lambda
x[0], x[1], x[2], x[3], x[4]),
elems=(
batched_boxes,
batched_scores,
batched_classes,
batched_detections_masks,
batched_segmentation_masks),
fn_output_signature={
'category_mask': tf.float32,
'instance_mask': tf.float32
})
for k, v in panoptic_masks.items():
panoptic_masks[k] = tf.cast(v, dtype=tf.int32)
return panoptic_masks
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config):
return cls(**config)
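# A minimal usage sketch (editorial addition mirroring the unit test below;
# tensor shapes and values are illustrative placeholders only):
generator = PanopticSegmentationGenerator(
    output_size=[640, 640], max_num_detections=100, stuff_classes_offset=3)
example_inputs = {
    'detection_boxes': tf.constant([[[167., 398., 342., 619.]]]),
    'detection_scores': tf.constant([[0.9]]),
    'detection_classes': tf.constant([[1.]]),
    'detection_masks': tf.random.uniform([1, 1, 112, 112]),
    'segmentation_outputs': tf.random.uniform([1, 640, 640, 8]),
}
# Returns int32 'category_mask' and 'instance_mask' tensors of shape
# [1, 640, 640].
example_outputs = tf.function(generator)(example_inputs)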
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for panoptic_segmentation_generator.py."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
PANOPTIC_SEGMENTATION_GENERATOR = panoptic_segmentation_generator.PanopticSegmentationGenerator
class PanopticSegmentationGeneratorTest(
parameterized.TestCase, tf.test.TestCase):
def test_serialize_deserialize(self):
config = {
'output_size': [640, 640],
'max_num_detections': 100,
'stuff_classes_offset': 90,
'mask_binarize_threshold': 0.5,
'score_threshold': 0.005,
'things_class_label': 1,
'void_class_label': 0,
'void_instance_id': -1
}
generator = PANOPTIC_SEGMENTATION_GENERATOR(**config)
expected_config = dict(config)
self.assertEqual(generator.get_config(), expected_config)
new_generator = PANOPTIC_SEGMENTATION_GENERATOR.from_config(
generator.get_config())
self.assertAllEqual(generator.get_config(), new_generator.get_config())
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy_gpu,
]))
def test_outputs(self, strategy):
# 0 represents the void class label
thing_class_ids = [0, 1, 2, 3, 4]
stuff_class_ids = [0, 5, 6, 7, 8, 9, 10]
all_class_ids = set(thing_class_ids + stuff_class_ids)
num_thing_classes = len(thing_class_ids)
num_stuff_classes = len(stuff_class_ids)
num_classes_for_segmentation = num_stuff_classes + 1
# all thing classes are mapped to class_id=1, stuff class ids are offset
# such that the stuff class_ids start from 2, this means the semantic
# segmentation head will have ground truths with class_ids belonging to
# [0, 1, 2, 3, 4, 5, 6, 7]
config = {
'output_size': [640, 640],
'max_num_detections': 100,
'stuff_classes_offset': 3,
'mask_binarize_threshold': 0.5,
'score_threshold': 0.005,
'things_class_label': 1,
'void_class_label': 0,
'void_instance_id': -1
}
generator = PANOPTIC_SEGMENTATION_GENERATOR(**config)
crop_height = 112
crop_width = 112
boxes = tf.constant([[
[167, 398, 342, 619],
[192, 171, 363, 449],
[211, 1, 382, 74]
]])
num_detections = boxes.get_shape().as_list()[1]
scores = tf.random.uniform([1, num_detections], 0, 1)
classes = tf.random.uniform(
[1, num_detections],
1, num_thing_classes, dtype=tf.int32)
masks = tf.random.normal(
[1, num_detections, crop_height, crop_width])
segmentation_mask = tf.random.uniform(
[1, *config['output_size']],
0, num_classes_for_segmentation, dtype=tf.int32)
segmentation_mask_one_hot = tf.one_hot(
segmentation_mask, depth=num_stuff_classes + 1)
inputs = {
'detection_boxes': boxes,
'detection_scores': scores,
'detection_classes': classes,
'detection_masks': masks,
'num_detections': tf.constant([num_detections]),
'segmentation_outputs': segmentation_mask_one_hot
}
def _run(inputs):
return generator(inputs=inputs)
@tf.function
def _distributed_run(inputs):
outputs = strategy.run(_run, args=((inputs,)))
return strategy.gather(outputs, axis=0)
outputs = _distributed_run(inputs)
self.assertIn('category_mask', outputs)
self.assertIn('instance_mask', outputs)
self.assertAllEqual(
outputs['category_mask'][0].get_shape().as_list(),
config['output_size'])
self.assertAllEqual(
outputs['instance_mask'][0].get_shape().as_list(),
config['output_size'])
for category_id in np.unique(outputs['category_mask']):
self.assertIn(category_id, all_class_ids)
if __name__ == '__main__':
tf.test.main()
@@ -25,31 +25,33 @@ from official.vision.beta.modeling import maskrcnn_model
class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
  """The Panoptic Segmentation model."""
def __init__(
    self,
    backbone: tf.keras.Model,
    decoder: tf.keras.Model,
    rpn_head: tf.keras.layers.Layer,
    detection_head: Union[tf.keras.layers.Layer,
                          List[tf.keras.layers.Layer]],
    roi_generator: tf.keras.layers.Layer,
    roi_sampler: Union[tf.keras.layers.Layer,
                       List[tf.keras.layers.Layer]],
    roi_aligner: tf.keras.layers.Layer,
    detection_generator: tf.keras.layers.Layer,
    panoptic_segmentation_generator: Optional[tf.keras.layers.Layer] = None,
    mask_head: Optional[tf.keras.layers.Layer] = None,
    mask_sampler: Optional[tf.keras.layers.Layer] = None,
    mask_roi_aligner: Optional[tf.keras.layers.Layer] = None,
    segmentation_backbone: Optional[tf.keras.Model] = None,
    segmentation_decoder: Optional[tf.keras.Model] = None,
    segmentation_head: tf.keras.layers.Layer = None,
    class_agnostic_bbox_pred: bool = False,
    cascade_class_ensemble: bool = False,
    min_level: Optional[int] = None,
    max_level: Optional[int] = None,
    num_scales: Optional[int] = None,
    aspect_ratios: Optional[List[float]] = None,
    anchor_size: Optional[float] = None,
    **kwargs):
"""Initializes the Panoptic Mask R-CNN model. """Initializes the Panoptic Mask R-CNN model.
Args: Args:
...@@ -62,6 +64,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -62,6 +64,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
detection heads. detection heads.
roi_aligner: the ROI aligner. roi_aligner: the ROI aligner.
detection_generator: the detection generator. detection_generator: the detection generator.
panoptic_segmentation_generator: the panoptic segmentation generator that
is used to merge instance and semantic segmentation masks.
    mask_head: the mask head.
    mask_sampler: the mask sampler.
    mask_roi_aligner: the ROI aligner for mask prediction.
@@ -120,6 +124,10 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
        'segmentation_head': segmentation_head
    })
if panoptic_segmentation_generator is not None:
self._config_dict.update(
{'panoptic_segmentation_generator': panoptic_segmentation_generator})
    if not self._include_mask:
      raise ValueError(
          '`mask_head` needs to be provided for Panoptic Mask R-CNN.')
@@ -131,6 +139,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
    self.segmentation_backbone = segmentation_backbone
    self.segmentation_decoder = segmentation_decoder
    self.segmentation_head = segmentation_head
self.panoptic_segmentation_generator = panoptic_segmentation_generator
  def call(self,
           images: tf.Tensor,
@@ -167,6 +176,10 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
          'segmentation_outputs': segmentation_outputs,
      })
if not training and self.panoptic_segmentation_generator is not None:
panoptic_outputs = self.panoptic_segmentation_generator(model_outputs)
model_outputs.update({'panoptic_outputs': panoptic_outputs})
    return model_outputs
  @property
...
@@ -15,7 +15,6 @@
"""Tests for panoptic_maskrcnn_model.py."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
@@ -35,6 +34,7 @@ from official.vision.beta.modeling.layers import roi_generator
from official.vision.beta.modeling.layers import roi_sampler
from official.vision.beta.ops import anchor
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
@@ -99,6 +99,10 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
    roi_sampler_obj = roi_sampler.ROISampler()
    roi_aligner_obj = roi_aligner.MultilevelROIAligner()
    detection_generator_obj = detection_generator.DetectionGenerator()
panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=[image_size, image_size],
max_num_detections=100,
stuff_classes_offset=90)
    mask_head = instance_heads.MaskHead(
        num_classes=num_classes, upsample_factor=2)
    mask_sampler_obj = mask_sampler.MaskSampler(
@@ -131,6 +135,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
        roi_sampler_obj,
        roi_aligner_obj,
        detection_generator_obj,
panoptic_segmentation_generator_obj,
        mask_head,
        mask_sampler_obj,
        mask_roi_aligner_obj,
@@ -163,15 +168,16 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          shared_backbone=[True, False],
          shared_decoder=[True, False],
          training=[True, False],
          generate_panoptic_masks=[True, False]))
  def test_forward(self, strategy, training,
                   shared_backbone, shared_decoder,
                   generate_panoptic_masks):
    num_classes = 3
    min_level = 3
    max_level = 4
@@ -223,6 +229,15 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
      roi_sampler_cascade.append(roi_sampler_obj)
    roi_aligner_obj = roi_aligner.MultilevelROIAligner()
    detection_generator_obj = detection_generator.DetectionGenerator()
if generate_panoptic_masks:
panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=list(image_size),
max_num_detections=100,
stuff_classes_offset=90)
else:
panoptic_segmentation_generator_obj = None
    mask_head = instance_heads.MaskHead(
        num_classes=num_classes, upsample_factor=2)
    mask_sampler_obj = mask_sampler.MaskSampler(
@@ -255,6 +270,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
        roi_sampler_obj,
        roi_aligner_obj,
        detection_generator_obj,
panoptic_segmentation_generator_obj,
        mask_head,
        mask_sampler_obj,
        mask_roi_aligner_obj,
@@ -300,10 +316,24 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
    self.assertIn('num_detections', results)
    self.assertIn('detection_masks', results)
    self.assertIn('segmentation_outputs', results)
    self.assertAllEqual(
        [2, image_size[0] // (2**level), image_size[1] // (2**level), 2],
        results['segmentation_outputs'].numpy().shape)
if generate_panoptic_masks:
self.assertIn('panoptic_outputs', results)
self.assertIn('category_mask', results['panoptic_outputs'])
self.assertIn('instance_mask', results['panoptic_outputs'])
self.assertAllEqual(
[2, image_size[0], image_size[1]],
results['panoptic_outputs']['category_mask'].numpy().shape)
self.assertAllEqual(
[2, image_size[0], image_size[1]],
results['panoptic_outputs']['instance_mask'].numpy().shape)
else:
self.assertNotIn('panoptic_outputs', results)
  @combinations.generate(
      combinations.combine(
          shared_backbone=[True, False], shared_decoder=[True, False]))
@@ -319,6 +349,10 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
    roi_sampler_obj = roi_sampler.ROISampler()
    roi_aligner_obj = roi_aligner.MultilevelROIAligner()
    detection_generator_obj = detection_generator.DetectionGenerator()
panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=[None, None],
max_num_detections=100,
stuff_classes_offset=90)
    segmentation_resnet_model_id = 101
    segmentation_output_stride = 16
    aspp_dilation_rates = [6, 12, 18]
@@ -356,6 +390,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
        roi_sampler_obj,
        roi_aligner_obj,
        detection_generator_obj,
panoptic_segmentation_generator_obj,
        mask_head,
        mask_sampler_obj,
        mask_roi_aligner_obj,
@@ -393,6 +428,10 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
    roi_sampler_obj = roi_sampler.ROISampler()
    roi_aligner_obj = roi_aligner.MultilevelROIAligner()
    detection_generator_obj = detection_generator.DetectionGenerator()
panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=[None, None],
max_num_detections=100,
stuff_classes_offset=90)
    segmentation_resnet_model_id = 101
    segmentation_output_stride = 16
    aspp_dilation_rates = [6, 12, 18]
@@ -430,6 +469,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
        roi_sampler_obj,
        roi_aligner_obj,
        detection_generator_obj,
panoptic_segmentation_generator_obj,
        mask_head,
        mask_sampler_obj,
        mask_roi_aligner_obj,
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -13,7 +13,7 @@
# limitations under the License.
"""Panoptic MaskRCNN task definition."""
from typing import Any, Dict, List, Mapping, Optional, Tuple
from absl import logging
import tensorflow as tf
@@ -22,6 +22,7 @@ from official.common import dataset_fn
from official.core import task_factory
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.evaluation import panoptic_quality_evaluator
from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.losses import segmentation_losses
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as exp_cfg
@@ -66,6 +67,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
      return
    def _get_checkpoint_path(checkpoint_dir_or_file):
checkpoint_path = checkpoint_dir_or_file
      if tf.io.gfile.isdir(checkpoint_dir_or_file):
        checkpoint_path = tf.train.latest_checkpoint(
            checkpoint_dir_or_file)
@@ -207,6 +209,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
                          tf.keras.metrics.Metric]:
    """Build detection metrics."""
    metrics = []
num_segmentation_classes = self.task_config.model.segmentation_model.num_classes
    if training:
      metric_names = [
          'total_loss',
@@ -225,7 +228,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
      if self.task_config.segmentation_evaluation.report_train_mean_iou:
        self.segmentation_train_mean_iou = segmentation_metrics.MeanIoU(
            name='train_mean_iou',
            num_classes=num_segmentation_classes,
            rescale_predictions=False,
            dtype=tf.float32)
@@ -239,9 +242,22 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
                               .segmentation_resize_eval_groundtruth)
      self.segmentation_perclass_iou_metric = segmentation_metrics.PerClassIoU(
          name='per_class_iou',
          num_classes=num_segmentation_classes,
          rescale_predictions=rescale_predictions,
          dtype=tf.float32)
if self.task_config.model.generate_panoptic_masks:
if not self.task_config.validation_data.parser.include_panoptic_masks:
raise ValueError('`include_panoptic_masks` should be set to True when'
' computing panoptic quality.')
pq_config = self.task_config.panoptic_quality_evaluator
self.panoptic_quality_metric = panoptic_quality_evaluator.PanopticQualityEvaluator(
num_categories=pq_config.num_categories,
ignored_label=pq_config.ignored_label,
max_instances_per_category=pq_config.max_instances_per_category,
offset=pq_config.offset,
is_thing=pq_config.is_thing)
    return metrics
  def train_step(self,
@@ -359,6 +375,16 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
                  segmentation_labels,
                  outputs['segmentation_outputs'])
      })
if self.task_config.model.generate_panoptic_masks:
pq_metric_labels = {
'category_mask':
labels['groundtruths']['gt_panoptic_category_mask'],
'instance_mask':
labels['groundtruths']['gt_panoptic_instance_mask']
}
logs.update({
self.panoptic_quality_metric.name:
(pq_metric_labels, outputs['panoptic_outputs'])})
    return logs
  def aggregate_logs(self, state=None, step_outputs=None):
@@ -366,6 +392,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
      self.coco_metric.reset_states()
      self.segmentation_perclass_iou_metric.reset_states()
      state = [self.coco_metric, self.segmentation_perclass_iou_metric]
if self.task_config.model.generate_panoptic_masks:
state += [self.panoptic_quality_metric]
    self.coco_metric.update_state(
        step_outputs[self.coco_metric.name][0],
@@ -373,6 +401,12 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
    self.segmentation_perclass_iou_metric.update_state(
        step_outputs[self.segmentation_perclass_iou_metric.name][0],
        step_outputs[self.segmentation_perclass_iou_metric.name][1])
if self.task_config.model.generate_panoptic_masks:
self.panoptic_quality_metric.update_state(
step_outputs[self.panoptic_quality_metric.name][0],
step_outputs[self.panoptic_quality_metric.name][1])
    return state
  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
@@ -388,4 +422,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
      result.update({'segmentation_iou/class_{}'.format(i): value})
    # Computes mean IoU
    result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()})
if self.task_config.model.generate_panoptic_masks:
for k, value in self.panoptic_quality_metric.result().items():
result['panoptic_quality/' + k] = value
    return result