Commit e8beb1f4 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 449345313
parent 1051697d
......@@ -13,8 +13,9 @@
# limitations under the License.
"""MaskRCNN task definition."""
import os
from typing import Any, Optional, List, Tuple, Mapping
from typing import Any, Dict, Optional, List, Tuple, Mapping
from absl import logging
import tensorflow as tf
......@@ -165,29 +166,30 @@ class MaskRCNNTask(base_task.Task):
return dataset
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build Mask R-CNN losses."""
params = self.task_config
cascade_ious = params.model.roi_sampler.cascade_iou_thresholds
def _build_rpn_losses(
self, outputs: Mapping[str, Any],
labels: Mapping[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
"""Build losses for Region Proposal Network (RPN)."""
rpn_score_loss_fn = maskrcnn_losses.RpnScoreLoss(
tf.shape(outputs['box_outputs'])[1])
rpn_box_loss_fn = maskrcnn_losses.RpnBoxLoss(
params.losses.rpn_huber_loss_delta)
self.task_config.losses.rpn_huber_loss_delta)
rpn_score_loss = tf.reduce_mean(
rpn_score_loss_fn(
outputs['rpn_scores'], labels['rpn_score_targets']))
rpn_score_loss_fn(outputs['rpn_scores'], labels['rpn_score_targets']))
rpn_box_loss = tf.reduce_mean(
rpn_box_loss_fn(
outputs['rpn_boxes'], labels['rpn_box_targets']))
rpn_box_loss_fn(outputs['rpn_boxes'], labels['rpn_box_targets']))
return rpn_score_loss, rpn_box_loss
def _build_frcnn_losses(
self, outputs: Mapping[str, Any],
labels: Mapping[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
"""Build losses for Fast R-CNN."""
cascade_ious = self.task_config.model.roi_sampler.cascade_iou_thresholds
frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss()
frcnn_box_loss_fn = maskrcnn_losses.FastrcnnBoxLoss(
params.losses.frcnn_huber_loss_delta,
params.model.detection_head.class_agnostic_bbox_pred)
self.task_config.losses.frcnn_huber_loss_delta,
self.task_config.model.detection_head.class_agnostic_bbox_pred)
# Final cls/box losses are computed as an average of all detection heads.
frcnn_cls_loss = 0.0
......@@ -212,23 +214,33 @@ class MaskRCNNTask(base_task.Task):
frcnn_box_loss += frcnn_box_loss_i
frcnn_cls_loss /= num_det_heads
frcnn_box_loss /= num_det_heads
return frcnn_cls_loss, frcnn_box_loss
def _build_mask_loss(self, outputs: Mapping[str, Any]) -> tf.Tensor:
"""Build losses for the masks."""
mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
mask_class_targets = outputs['mask_class_targets']
if self.task_config.allowed_mask_class_ids is not None:
# Classes with ID=0 are ignored by mask_loss_fn in loss computation.
mask_class_targets = zero_out_disallowed_class_ids(
mask_class_targets, self.task_config.allowed_mask_class_ids)
return tf.reduce_mean(
mask_loss_fn(outputs['mask_outputs'], outputs['mask_targets'],
mask_class_targets))
if params.model.include_mask:
mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
mask_class_targets = outputs['mask_class_targets']
if self._task_config.allowed_mask_class_ids is not None:
# Classes with ID=0 are ignored by mask_loss_fn in loss computation.
mask_class_targets = zero_out_disallowed_class_ids(
mask_class_targets, self._task_config.allowed_mask_class_ids)
mask_loss = tf.reduce_mean(
mask_loss_fn(
outputs['mask_outputs'],
outputs['mask_targets'],
mask_class_targets))
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None) -> Dict[str, tf.Tensor]:
"""Build Mask R-CNN losses."""
rpn_score_loss, rpn_box_loss = self._build_rpn_losses(outputs, labels)
frcnn_cls_loss, frcnn_box_loss = self._build_frcnn_losses(outputs, labels)
if self.task_config.model.include_mask:
mask_loss = self._build_mask_loss(outputs)
else:
mask_loss = 0.0
mask_loss = tf.constant(0.0, dtype=tf.float32)
params = self.task_config
model_loss = (
params.losses.rpn_score_weight * rpn_score_loss +
params.losses.rpn_box_weight * rpn_box_loss +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment