Unverified Commit add9b37e authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added `instance_segmentation_weight` config param

parent b547ab67
......@@ -106,6 +106,7 @@ class Losses(maskrcnn.Losses):
default_factory=list)
semantic_segmentation_use_groundtruth_dimension: bool = True
semantic_segmentation_top_k_percent_pixels: float = 1.0
instance_segmentation_weight: float = 1.0
semantic_segmentation_weight: float = 1.0
......
......@@ -178,6 +178,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
ignore_label=params.semantic_segmentation_ignore_label,
use_groundtruth_dimension=use_groundtruth_dimension,
top_k_percent_pixels=params.semantic_segmentation_top_k_percent_pixels)
instance_segmentation_weight = params.instance_segmentation_weight
semantic_segmentation_weight = params.semantic_segmentation_weight
losses = super(PanopticMaskRCNNTask, self).build_losses(
......@@ -190,7 +192,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
labels['gt_segmentation_mask'])
model_loss = (
maskrcnn_loss + semantic_segmentation_weight * segmentation_loss)
instance_segmentation_weight * maskrcnn_loss +
semantic_segmentation_weight * segmentation_loss)
total_loss = model_loss
if aux_losses:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment