Commit 43f2ce0b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 443222400
parent c279a62f
...@@ -33,14 +33,13 @@ class SegmentationLoss: ...@@ -33,14 +33,13 @@ class SegmentationLoss:
self._use_groundtruth_dimension = use_groundtruth_dimension self._use_groundtruth_dimension = use_groundtruth_dimension
self._label_smoothing = label_smoothing self._label_smoothing = label_smoothing
def __call__(self, logits, labels): def __call__(self, logits, labels, **kwargs):
_, height, width, num_classes = logits.get_shape().as_list() _, height, width, num_classes = logits.get_shape().as_list()
if self._use_groundtruth_dimension: if self._use_groundtruth_dimension:
# TODO(arashwan): Test using align corners to match deeplab alignment. # TODO(arashwan): Test using align corners to match deeplab alignment.
logits = tf.image.resize( logits = tf.image.resize(
logits, tf.shape(labels)[1:3], logits, tf.shape(labels)[1:3], method=tf.image.ResizeMethod.BILINEAR)
method=tf.image.ResizeMethod.BILINEAR)
else: else:
labels = tf.image.resize( labels = tf.image.resize(
labels, (height, width), labels, (height, width),
...@@ -54,11 +53,9 @@ class SegmentationLoss: ...@@ -54,11 +53,9 @@ class SegmentationLoss:
labels = tf.squeeze(tf.cast(labels, tf.int32), axis=3) labels = tf.squeeze(tf.cast(labels, tf.int32), axis=3)
valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=3) valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=3)
onehot_labels = tf.one_hot(labels, num_classes)
onehot_labels = onehot_labels * (
1 - self._label_smoothing) + self._label_smoothing / num_classes
cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits( cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
labels=onehot_labels, logits=logits) labels=self.get_labels_with_prob(labels, logits, **kwargs),
logits=logits)
if not self._class_weights: if not self._class_weights:
class_weights = [1] * num_classes class_weights = [1] * num_classes
...@@ -90,6 +87,26 @@ class SegmentationLoss: ...@@ -90,6 +87,26 @@ class SegmentationLoss:
return loss return loss
def get_labels_with_prob(self, labels, logits, **unused_kwargs):
"""Get a tensor representing the probability of each class for each pixel.
This method can be overridden in subclasses for customizing loss function.
Args:
labels: A float tensor in shape (batch_size, height, width), which is the
label map of the ground truth.
logits: A float tensor in shape (batch_size, height, width, num_classes)
which is the output of the network.
**unused_kwargs: Unused keyword arguments.
Returns:
A float tensor in shape (batch_size, height, width, num_classes).
"""
num_classes = logits.get_shape().as_list()[-1]
onehot_labels = tf.one_hot(labels, num_classes)
return onehot_labels * (
1 - self._label_smoothing) + self._label_smoothing / num_classes
def get_actual_mask_scores(logits, labels, ignore_label): def get_actual_mask_scores(logits, labels, ignore_label):
"""Gets actual mask scores.""" """Gets actual mask scores."""
...@@ -97,8 +114,7 @@ def get_actual_mask_scores(logits, labels, ignore_label): ...@@ -97,8 +114,7 @@ def get_actual_mask_scores(logits, labels, ignore_label):
batch_size = tf.shape(logits)[0] batch_size = tf.shape(logits)[0]
logits = tf.stop_gradient(logits) logits = tf.stop_gradient(logits)
labels = tf.image.resize( labels = tf.image.resize(
labels, (height, width), labels, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
predicted_labels = tf.argmax(logits, -1, output_type=tf.int32) predicted_labels = tf.argmax(logits, -1, output_type=tf.int32)
flat_predictions = tf.reshape(predicted_labels, [batch_size, -1]) flat_predictions = tf.reshape(predicted_labels, [batch_size, -1])
flat_labels = tf.cast(tf.reshape(labels, [batch_size, -1]), tf.int32) flat_labels = tf.cast(tf.reshape(labels, [batch_size, -1]), tf.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment