Unverified Commit 6e0d65cb authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

Refactor losses; implement weighted semantic segmentation loss.

parent abfd0698
......@@ -12,54 +12,112 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Instance center losses used for panoptic deeplab model."""
"""Losses used for panoptic deeplab model."""
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling import tf_utils
from official.vision.beta.projects.panoptic_maskrcnn.ops import mask_ops
EPSILON = 1e-5  # Small constant to keep loss normalizers strictly positive.


class WeightedBootstrappedCrossEntropyLoss:
  """Weighted semantic segmentation loss.

  Computes per-pixel softmax cross-entropy with optional label smoothing,
  per-class weighting, per-pixel sample weighting, and hard-pixel
  bootstrapping (averaging only the top-k percent highest-loss pixels).
  """

  def __init__(self, label_smoothing, class_weights, ignore_label,
               top_k_percent_pixels=1.0):
    """Initializes the loss.

    Args:
      label_smoothing: Float amount of smoothing applied to the one-hot
        ground-truth labels.
      class_weights: Optional list of per-class loss weights. If falsy,
        every class gets weight 1.
      ignore_label: Integer label value whose pixels are excluded from the
        loss.
      top_k_percent_pixels: Float in (0, 1]. When < 1.0, only the hardest
        top-k fraction of pixels contribute to the final loss.
    """
    self._top_k_percent_pixels = top_k_percent_pixels
    self._class_weights = class_weights
    self._ignore_label = ignore_label
    self._label_smoothing = label_smoothing

  def __call__(self, logits, labels, sample_weight=None):
    """Computes the weighted (optionally bootstrapped) cross-entropy loss.

    Args:
      logits: A [batch, height, width, num_classes] float tensor of
        unnormalized class scores.
      labels: A [batch, height', width', 1] integer tensor of ground-truth
        class ids; logits are bilinearly resized to this spatial size.
      sample_weight: Optional per-pixel weight tensor broadcastable to
        [batch, height', width'].

    Returns:
      A scalar loss tensor.

    Raises:
      ValueError: If `class_weights` is provided but its length differs
        from `num_classes`.
    """
    _, _, _, num_classes = logits.get_shape().as_list()

    # Match the logits' spatial resolution to the ground-truth labels.
    logits = tf.image.resize(
        logits, tf.shape(labels)[1:3],
        method=tf.image.ResizeMethod.BILINEAR)

    valid_mask = tf.not_equal(labels, self._ignore_label)
    normalizer = tf.reduce_sum(tf.cast(valid_mask, tf.float32)) + EPSILON
    # Assign pixel with ignore label to class 0 (background). The loss on the
    # pixel will later be masked out.
    labels = tf.where(valid_mask, labels, tf.zeros_like(labels))

    labels = tf.squeeze(tf.cast(labels, tf.int32), axis=3)
    valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=3)
    onehot_labels = tf.one_hot(labels, num_classes)
    onehot_labels = onehot_labels * (
        1 - self._label_smoothing) + self._label_smoothing / num_classes
    cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=onehot_labels, logits=logits)

    if not self._class_weights:
      class_weights = [1] * num_classes
    else:
      class_weights = self._class_weights

    if num_classes != len(class_weights):
      raise ValueError(
          'Length of class_weights should be {}'.format(num_classes))

    # Fold the per-class weights into the per-pixel validity mask: each
    # pixel picks up the weight of its ground-truth class.
    weight_mask = tf.einsum('...y,y->...',
                            tf.one_hot(labels, num_classes, dtype=tf.float32),
                            tf.constant(class_weights, tf.float32))
    valid_mask *= weight_mask

    if sample_weight is not None:
      valid_mask *= sample_weight
    cross_entropy_loss *= tf.cast(valid_mask, tf.float32)

    if self._top_k_percent_pixels >= 1.0:
      loss = tf.reduce_sum(cross_entropy_loss) / normalizer
    else:
      # Bootstrapping: keep only the hardest top-k fraction of pixels and
      # normalize by the number of those that carry non-zero loss.
      cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1])
      top_k_pixels = tf.cast(
          self._top_k_percent_pixels *
          tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32)
      top_k_losses, _ = tf.math.top_k(
          cross_entropy_loss, k=top_k_pixels, sorted=True)
      normalizer = tf.reduce_sum(
          tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
      loss = tf.reduce_sum(top_k_losses) / normalizer
    return loss
class CenterHeatmapLoss:
  """Mean-squared-error loss for instance-center heatmap prediction."""

  def __init__(self):
    # Element-wise MSE; reduction to a scalar happens in __call__.
    self._loss_fn = tf.losses.mean_squared_error

  def __call__(self, logits, labels, sample_weight=None):
    """Computes the center-heatmap loss.

    Args:
      logits: A [batch, height, width, channels] float tensor of predicted
        heatmaps; bilinearly resized to the labels' spatial size.
      labels: A [batch, height', width', channels] float tensor of
        ground-truth heatmaps with static spatial dimensions.
      sample_weight: Optional weight tensor broadcastable to the per-pixel
        loss (e.g. a validity mask).

    Returns:
      A scalar loss tensor, averaged via `tf_utils.safe_mean`.
    """
    _, height, width, _ = labels.get_shape().as_list()
    logits = tf.image.resize(
        logits,
        size=[height, width],
        method=tf.image.ResizeMethod.BILINEAR)
    loss = self._loss_fn(y_true=labels, y_pred=logits)
    if sample_weight is not None:
      loss *= sample_weight
    return tf_utils.safe_mean(loss)
class CenterOffsetLoss:
  """Mean-absolute-error loss for instance-center offset prediction."""

  def __init__(self):
    # Element-wise MAE; reduction to a scalar happens in __call__.
    self._loss_fn = tf.losses.mean_absolute_error

  def __call__(self, logits, labels, sample_weight=None):
    """Computes the center-offset loss.

    Args:
      logits: A [batch, height, width, 2] float tensor of predicted offset
        vectors; resized to the labels' spatial size via
        `mask_ops.resize_and_rescale_offsets` (which presumably also rescales
        the offset magnitudes for the new resolution — see mask_ops).
      labels: A [batch, height', width', 2] float tensor of ground-truth
        offsets with static spatial dimensions.
      sample_weight: Optional weight tensor broadcastable to the per-pixel
        loss (e.g. a things-only mask).

    Returns:
      A scalar loss tensor, averaged via `tf_utils.safe_mean`.
    """
    _, height, width, _ = labels.get_shape().as_list()
    logits = mask_ops.resize_and_rescale_offsets(
        logits, target_size=[height, width])
    loss = self._loss_fn(y_true=labels, y_pred=logits)
    if sample_weight is not None:
      loss *= sample_weight
    return tf_utils.safe_mean(loss)
......@@ -28,7 +28,6 @@ from official.vision.beta.projects.panoptic_maskrcnn.losses import panoptic_deep
from official.vision.dataloaders import input_reader_factory
from official.vision.evaluation import panoptic_quality_evaluator
from official.vision.evaluation import segmentation_metrics
from official.vision.losses import segmentation_losses
@task_factory.register_task_cls(exp_cfg.PanopticDeeplabTask)
......@@ -131,28 +130,29 @@ class PanopticDeeplabTask(base_task.Task):
The total loss tensor.
"""
loss_config = self._task_config.losses
segmentation_loss_fn = segmentation_losses.SegmentationLoss(
segmentation_loss_fn = panoptic_deeplab_losses.WeightedBootstrappedCrossEntropyLoss(
loss_config.label_smoothing,
loss_config.class_weights,
loss_config.ignore_label,
use_groundtruth_dimension=loss_config.use_groundtruth_dimension,
top_k_percent_pixels=loss_config.top_k_percent_pixels)
instance_center_heatmap_loss_fn = panoptic_deeplab_losses.CenterHeatmapLoss(
use_groundtruth_dimension=loss_config.use_groundtruth_dimension)
instance_center_offset_loss_fn = panoptic_deeplab_losses.CenterOffsetLoss(
use_groundtruth_dimension=loss_config.use_groundtruth_dimension)
instance_center_heatmap_loss_fn = panoptic_deeplab_losses.CenterHeatmapLoss()
instance_center_offset_loss_fn = panoptic_deeplab_losses.CenterOffsetLoss()
segmentation_loss = segmentation_loss_fn(
model_outputs['segmentation_outputs'],
labels['category_mask'])
semantic_weights = tf.cast(
labels['semantic_weights'],
dtype=model_outputs['instance_centers_heatmap'].dtype)
things_mask = tf.cast(
tf.squeeze(labels['things_mask'], axis=3),
labels['things_mask'],
dtype=model_outputs['instance_centers_heatmap'].dtype)
valid_mask = tf.cast(
tf.squeeze(labels['valid_mask'], axis=3),
labels['valid_mask'],
dtype=model_outputs['instance_centers_heatmap'].dtype)
segmentation_loss = segmentation_loss_fn(
model_outputs['segmentation_outputs'],
labels['category_mask'],
sample_weight=semantic_weights)
instance_center_heatmap_loss = instance_center_heatmap_loss_fn(
model_outputs['instance_centers_heatmap'],
labels['instance_centers_heatmap'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment