Commit 35c3a79f authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by Jaeyoun Kim
Browse files

fixed linting errors (#10053)

parent b3be14bc
...@@ -78,7 +78,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -78,7 +78,7 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
will allow the segmentation head to use a standlone decoder. Setting will allow the segmentation head to use a standlone decoder. Setting
`segmentation_decoder=None` would enable decoder sharing between `segmentation_decoder=None` would enable decoder sharing between
the MaskRCNN model and segmentation head. Decoders can only be shared the MaskRCNN model and segmentation head. Decoders can only be shared
when `segmentation_backbone` is shared as well. when `segmentation_backbone` is shared as well.
segmentation_head: segmentatation head for panoptic task. segmentation_head: segmentatation head for panoptic task.
class_agnostic_bbox_pred: if True, perform class agnostic bounding box class_agnostic_bbox_pred: if True, perform class agnostic bounding box
prediction. Needs to be `True` for Cascade RCNN models. prediction. Needs to be `True` for Cascade RCNN models.
...@@ -98,28 +98,28 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -98,28 +98,28 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
**kwargs: keyword arguments to be passed. **kwargs: keyword arguments to be passed.
""" """
super(PanopticMaskRCNNModel, self).__init__( super(PanopticMaskRCNNModel, self).__init__(
backbone=backbone, backbone=backbone,
decoder=decoder, decoder=decoder,
rpn_head=rpn_head, rpn_head=rpn_head,
detection_head=detection_head, detection_head=detection_head,
roi_generator=roi_generator, roi_generator=roi_generator,
roi_sampler=roi_sampler, roi_sampler=roi_sampler,
roi_aligner=roi_aligner, roi_aligner=roi_aligner,
detection_generator=detection_generator, detection_generator=detection_generator,
mask_head=mask_head, mask_head=mask_head,
mask_sampler=mask_sampler, mask_sampler=mask_sampler,
mask_roi_aligner=mask_roi_aligner, mask_roi_aligner=mask_roi_aligner,
class_agnostic_bbox_pred=class_agnostic_bbox_pred, class_agnostic_bbox_pred=class_agnostic_bbox_pred,
cascade_class_ensemble=cascade_class_ensemble, cascade_class_ensemble=cascade_class_ensemble,
min_level=min_level, min_level=min_level,
max_level=max_level, max_level=max_level,
num_scales=num_scales, num_scales=num_scales,
aspect_ratios=aspect_ratios, aspect_ratios=aspect_ratios,
anchor_size=anchor_size, anchor_size=anchor_size,
**kwargs) **kwargs)
self._config_dict.update({ self._config_dict.update({
'segmentation_backbone':segmentation_backbone, 'segmentation_backbone': segmentation_backbone,
'segmentation_decoder': segmentation_decoder, 'segmentation_decoder': segmentation_decoder,
'segmentation_head': segmentation_head 'segmentation_head': segmentation_head
}) })
...@@ -129,7 +129,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -129,7 +129,8 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
'`mask_head` needs to be provided for Panoptic Mask R-CNN.') '`mask_head` needs to be provided for Panoptic Mask R-CNN.')
if segmentation_backbone is not None and segmentation_decoder is None: if segmentation_backbone is not None and segmentation_decoder is None:
raise ValueError( raise ValueError(
'`segmentation_decoder` needs to be provided for Panoptic Mask R-CNN if `backbone` is not shared.' '`segmentation_decoder` needs to be provided for Panoptic Mask R-CNN'\
'if `backbone` is not shared.'
) )
self.segmentation_backbone = segmentation_backbone self.segmentation_backbone = segmentation_backbone
...@@ -270,18 +271,18 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel): ...@@ -270,18 +271,18 @@ class PanopticMaskRCNNModel(maskrcnn_model.MaskRCNNModel):
if self.segmentation_backbone is not None: if self.segmentation_backbone is not None:
backbone_features = self.segmentation_backbone( backbone_features = self.segmentation_backbone(
images, images,
training=training) training=training)
if self.segmentation_decoder is not None: if self.segmentation_decoder is not None:
decoder_features = self.segmentation_decoder( decoder_features = self.segmentation_decoder(
backbone_features,
training=training)
segmentation_outputs = self.segmentation_head(
backbone_features, backbone_features,
decoder_features,
training=training) training=training)
segmentation_outputs = self.segmentation_head(
backbone_features,
decoder_features,
training=training)
model_outputs.update({ model_outputs.update({
'segmentation_outputs': segmentation_outputs, 'segmentation_outputs': segmentation_outputs,
......
...@@ -23,7 +23,8 @@ import tensorflow as tf ...@@ -23,7 +23,8 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import \
panoptic_maskrcnn_model
from official.vision.beta.modeling.backbones import resnet from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import fpn from official.vision.beta.modeling.decoders import fpn
from official.vision.beta.modeling.decoders import aspp from official.vision.beta.modeling.decoders import aspp
...@@ -111,14 +112,14 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -111,14 +112,14 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
if not shared_decoder: if not shared_decoder:
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things, num_classes=2, # stuff and common class for things,
level=level, level=level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
backbone, backbone,
...@@ -316,9 +317,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -316,9 +317,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
self.assertIn('detection_masks', results) self.assertIn('detection_masks', results)
self.assertIn('segmentation_outputs', results) self.assertIn('segmentation_outputs', results)
self.assertAllEqual( self.assertAllEqual(
[2, [2, image_size[0] // (2**level),
image_size[0] // (2**level), image_size[1] // (2**level), 2],
image_size[1] // (2**level), 2],
results['segmentation_outputs'].numpy().shape) results['segmentation_outputs'].numpy().shape)
@combinations.generate( @combinations.generate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment