Unverified Commit b851571d authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

build `PanopticMaskRCNNModel` in factory

parent 12ef2884
......@@ -22,6 +22,7 @@ from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
def build_panoptic_maskrcnn(
......@@ -73,6 +74,8 @@ def build_panoptic_maskrcnn(
segmentation_head_config = segmentation_config.head
detection_head_config = model_config.detection_head
postprocessing_config = \
model_config.panoptic_segmentation_generator
segmentation_head = segmentation_heads.SegmentationHead(
num_classes=segmentation_config.num_classes,
......@@ -90,6 +93,16 @@ def build_panoptic_maskrcnn(
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
panoptic_segmentation_generator_obj = \
panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=postprocessing_config.output_size,
stuff_classes_offset=postprocessing_config.stuff_classes_offset,
mask_binarize_threshold=postprocessing_config.mask_binarize_threshold,
score_threshold=postprocessing_config.score_threshold,
things_class_label=postprocessing_config.things_class_label,
void_class_label=postprocessing_config.void_class_label,
void_instance_id=postprocessing_config.void_instance_id)
# Combines maskrcnn, and segmentation models to build panoptic segmentation
# model.
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
......@@ -101,6 +114,7 @@ def build_panoptic_maskrcnn(
roi_sampler=maskrcnn_model.roi_sampler,
roi_aligner=maskrcnn_model.roi_aligner,
detection_generator=maskrcnn_model.detection_generator,
panoptic_segmentation_generator=panoptic_segmentation_generator_obj,
mask_head=maskrcnn_model.mask_head,
mask_sampler=maskrcnn_model.mask_sampler,
mask_roi_aligner=maskrcnn_model.mask_roi_aligner,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment