Unverified Commit 6e0d65cb authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

refactor losses; implemented weighted semantic loss

parent abfd0698
...@@ -12,54 +12,112 @@ ...@@ -12,54 +12,112 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Instance center losses used for panoptic deeplab model.""" """Losses used for panoptic deeplab model."""
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.projects.panoptic_maskrcnn.ops import mask_ops
EPSILON = 1e-5
class WeightedBootstrappedCrossEntropyLoss:
"""Weighted semantic segmentation loss."""
def __init__(self, label_smoothing, class_weights, ignore_label,
top_k_percent_pixels=1.0):
self._top_k_percent_pixels = top_k_percent_pixels
self._class_weights = class_weights
self._ignore_label = ignore_label
self._label_smoothing = label_smoothing
class CenterLoss: def __call__(self, logits, labels, sample_weight=None):
"""Instance center loss.""" _, _, _, num_classes = logits.get_shape().as_list()
_LOSS_FN = {
'mse': tf.losses.mean_squared_error,
'mae': tf.losses.mean_absolute_error
}
def __init__(self, use_groundtruth_dimension: bool, loss_type: str): logits = tf.image.resize(
if loss_type.lower() not in {'mse', 'mae'}: logits, tf.shape(labels)[1:3],
raise ValueError('Unsupported `loss_type` supported. Available loss ' method=tf.image.ResizeMethod.BILINEAR)
'types: mse/mae')
self._use_groundtruth_dimension = use_groundtruth_dimension valid_mask = tf.not_equal(labels, self._ignore_label)
self.loss_type = loss_type normalizer = tf.reduce_sum(tf.cast(valid_mask, tf.float32)) + EPSILON
self._loss_fn = CenterLoss._LOSS_FN[self.loss_type] # Assign pixel with ignore label to class 0 (background). The loss on the
# pixel will later be masked out.
labels = tf.where(valid_mask, labels, tf.zeros_like(labels))
def __call__(self, logits, labels, sample_weight): labels = tf.squeeze(tf.cast(labels, tf.int32), axis=3)
_, height, width, _ = logits.get_shape().as_list() valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=3)
onehot_labels = tf.one_hot(labels, num_classes)
onehot_labels = onehot_labels * (
1 - self._label_smoothing) + self._label_smoothing / num_classes
cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
labels=onehot_labels, logits=logits)
if self._use_groundtruth_dimension: if not self._class_weights:
logits = tf.image.resize( class_weights = [1] * num_classes
logits, tf.shape(labels)[1:3],
method=tf.image.ResizeMethod.BILINEAR)
else: else:
labels = tf.image.resize( class_weights = self._class_weights
labels, (height, width),
method=tf.image.ResizeMethod.BILINEAR) if num_classes != len(class_weights):
raise ValueError(
'Length of class_weights should be {}'.format(num_classes))
weight_mask = tf.einsum('...y,y->...',
tf.one_hot(labels, num_classes, dtype=tf.float32),
tf.constant(class_weights, tf.float32))
valid_mask *= weight_mask
if sample_weight is not None:
valid_mask *= sample_weight
cross_entropy_loss *= tf.cast(valid_mask, tf.float32)
if self._top_k_percent_pixels >= 1.0:
loss = tf.reduce_sum(cross_entropy_loss) / normalizer
else:
cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1])
top_k_pixels = tf.cast(
self._top_k_percent_pixels *
tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32)
top_k_losses, _ = tf.math.top_k(
cross_entropy_loss, k=top_k_pixels, sorted=True)
normalizer = tf.reduce_sum(
tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
loss = tf.reduce_sum(top_k_losses) / normalizer
return loss
class CenterHeatmapLoss:
def __init__(self):
self._loss_fn = tf.losses.mean_squared_error
def __call__(self, logits, labels, sample_weight=None):
_, height, width, _ = labels.get_shape().as_list()
logits = tf.image.resize(
logits,
size=[height, width],
method=tf.image.ResizeMethod.BILINEAR)
loss = self._loss_fn(y_true=labels, y_pred=logits) loss = self._loss_fn(y_true=labels, y_pred=logits)
return tf_utils.safe_mean(loss * sample_weight)
if sample_weight is not None:
loss *= sample_weight
return tf_utils.safe_mean(loss)
class CenterOffsetLoss:
def __init__(self):
self._loss_fn = tf.losses.mean_absolute_error
def __call__(self, logits, labels, sample_weight=None):
_, height, width, _ = labels.get_shape().as_list()
logits = mask_ops.resize_and_rescale_offsets(
logits, target_size=[height, width])
loss = self._loss_fn(y_true=labels, y_pred=logits)
class CenterHeatmapLoss(CenterLoss): if sample_weight is not None:
def __init__(self, use_groundtruth_dimension): loss *= sample_weight
super(CenterHeatmapLoss, self).__init__(
use_groundtruth_dimension=use_groundtruth_dimension,
loss_type='mse')
class CenterOffsetLoss(CenterLoss): return tf_utils.safe_mean(loss)
def __init__(self, use_groundtruth_dimension):
super(CenterOffsetLoss, self).__init__(
use_groundtruth_dimension=use_groundtruth_dimension,
loss_type='mae')
...@@ -28,7 +28,6 @@ from official.vision.beta.projects.panoptic_maskrcnn.losses import panoptic_deep ...@@ -28,7 +28,6 @@ from official.vision.beta.projects.panoptic_maskrcnn.losses import panoptic_deep
from official.vision.dataloaders import input_reader_factory from official.vision.dataloaders import input_reader_factory
from official.vision.evaluation import panoptic_quality_evaluator from official.vision.evaluation import panoptic_quality_evaluator
from official.vision.evaluation import segmentation_metrics from official.vision.evaluation import segmentation_metrics
from official.vision.losses import segmentation_losses
@task_factory.register_task_cls(exp_cfg.PanopticDeeplabTask) @task_factory.register_task_cls(exp_cfg.PanopticDeeplabTask)
...@@ -131,28 +130,29 @@ class PanopticDeeplabTask(base_task.Task): ...@@ -131,28 +130,29 @@ class PanopticDeeplabTask(base_task.Task):
The total loss tensor. The total loss tensor.
""" """
loss_config = self._task_config.losses loss_config = self._task_config.losses
segmentation_loss_fn = segmentation_losses.SegmentationLoss( segmentation_loss_fn = panoptic_deeplab_losses.WeightedBootstrappedCrossEntropyLoss(
loss_config.label_smoothing, loss_config.label_smoothing,
loss_config.class_weights, loss_config.class_weights,
loss_config.ignore_label, loss_config.ignore_label,
use_groundtruth_dimension=loss_config.use_groundtruth_dimension,
top_k_percent_pixels=loss_config.top_k_percent_pixels) top_k_percent_pixels=loss_config.top_k_percent_pixels)
instance_center_heatmap_loss_fn = panoptic_deeplab_losses.CenterHeatmapLoss( instance_center_heatmap_loss_fn = panoptic_deeplab_losses.CenterHeatmapLoss()
use_groundtruth_dimension=loss_config.use_groundtruth_dimension) instance_center_offset_loss_fn = panoptic_deeplab_losses.CenterOffsetLoss()
instance_center_offset_loss_fn = panoptic_deeplab_losses.CenterOffsetLoss(
use_groundtruth_dimension=loss_config.use_groundtruth_dimension)
segmentation_loss = segmentation_loss_fn(
model_outputs['segmentation_outputs'],
labels['category_mask'])
semantic_weights = tf.cast(
labels['semantic_weights'],
dtype=model_outputs['instance_centers_heatmap'].dtype)
things_mask = tf.cast( things_mask = tf.cast(
tf.squeeze(labels['things_mask'], axis=3), labels['things_mask'],
dtype=model_outputs['instance_centers_heatmap'].dtype) dtype=model_outputs['instance_centers_heatmap'].dtype)
valid_mask = tf.cast( valid_mask = tf.cast(
tf.squeeze(labels['valid_mask'], axis=3), labels['valid_mask'],
dtype=model_outputs['instance_centers_heatmap'].dtype) dtype=model_outputs['instance_centers_heatmap'].dtype)
segmentation_loss = segmentation_loss_fn(
model_outputs['segmentation_outputs'],
labels['category_mask'],
sample_weight=semantic_weights)
instance_center_heatmap_loss = instance_center_heatmap_loss_fn( instance_center_heatmap_loss = instance_center_heatmap_loss_fn(
model_outputs['instance_centers_heatmap'], model_outputs['instance_centers_heatmap'],
labels['instance_centers_heatmap'], labels['instance_centers_heatmap'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment