Commit 35c3a79f authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by Jaeyoun Kim
Browse files

fixed linting errors (#10053)

parent b3be14bc
......@@ -78,7 +78,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
will allow the segmentation head to use a standlone decoder. Setting
`segmentation_decoder=None` would enable decoder sharing between
the MaskRCNN model and segmentation head. Decoders can only be shared
when `segmentation_backbone` is shared as well.
when `segmentation_backbone` is shared as well.
segmentation_head: segmentatation head for panoptic task.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models.
......@@ -98,28 +98,28 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
**kwargs: keyword arguments to be passed.
"""
super(PanopticMaskRCNNModel, self).__init__(
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator,
roi_sampler=roi_sampler,
roi_aligner=roi_aligner,
detection_generator=detection_generator,
mask_head=mask_head,
mask_sampler=mask_sampler,
mask_roi_aligner=mask_roi_aligner,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
**kwargs)
backbone=backbone,
decoder=decoder,
rpn_head=rpn_head,
detection_head=detection_head,
roi_generator=roi_generator,
roi_sampler=roi_sampler,
roi_aligner=roi_aligner,
detection_generator=detection_generator,
mask_head=mask_head,
mask_sampler=mask_sampler,
mask_roi_aligner=mask_roi_aligner,
class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level,
max_level=max_level,
num_scales=num_scales,
aspect_ratios=aspect_ratios,
anchor_size=anchor_size,
**kwargs)
self._config_dict.update({
'segmentation_backbone':segmentation_backbone,
'segmentation_backbone': segmentation_backbone,
'segmentation_decoder': segmentation_decoder,
'segmentation_head': segmentation_head
})
......@@ -129,7 +129,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
'`mask_head` needs to be provided for Panoptic Mask R-CNN.')
if segmentation_backbone is not None and segmentation_decoder is None:
raise ValueError(
'`segmentation_decoder` needs to be provided for Panoptic Mask R-CNN if `backbone` is not shared.'
'`segmentation_decoder` needs to be provided for Panoptic Mask R-CNN'\
'if `backbone` is not shared.'
)
self.segmentation_backbone = segmentation_backbone
......@@ -270,18 +271,18 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
if self.segmentation_backbone is not None:
backbone_features = self.segmentation_backbone(
images,
training=training)
images,
training=training)
if self.segmentation_decoder is not None:
decoder_features = self.segmentation_decoder(
backbone_features,
training=training)
segmentation_outputs = self.segmentation_head(
backbone_features,
decoder_features,
training=training)
segmentation_outputs = self.segmentation_head(
backbone_features,
decoder_features,
training=training)
model_outputs.update({
'segmentation_outputs': segmentation_outputs,
......
......@@ -23,7 +23,8 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling import \
panoptic_maskrcnn_model
from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import fpn
from official.vision.beta.modeling.decoders import aspp
......@@ -111,14 +112,14 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
if not shared_decoder:
level = aspp_decoder_level
segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates)
level=level, dilation_rates=aspp_dilation_rates)
else:
level = fpn_decoder_level
segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
num_classes=2, # stuff and common class for things,
level=level,
num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone,
......@@ -316,9 +317,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertIn('detection_masks', results)
self.assertIn('segmentation_outputs', results)
self.assertAllEqual(
[2,
image_size[0] // (2**level),
image_size[1] // (2**level), 2],
[2, image_size[0] // (2**level),
image_size[1] // (2**level), 2],
results['segmentation_outputs'].numpy().shape)
@combinations.generate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment