"vscode:/vscode.git/clone" did not exist on "ea22f8ec9de7f4067f265ebda483c48650c3595e"
Commit bb6b143c authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 443222400
parent 2ce12046
@@ -33,14 +33,13 @@ class SegmentationLoss:
     self._use_groundtruth_dimension = use_groundtruth_dimension
     self._label_smoothing = label_smoothing
 
-  def __call__(self, logits, labels):
+  def __call__(self, logits, labels, **kwargs):
     _, height, width, num_classes = logits.get_shape().as_list()
 
     if self._use_groundtruth_dimension:
       # TODO(arashwan): Test using align corners to match deeplab alignment.
       logits = tf.image.resize(
-          logits, tf.shape(labels)[1:3],
-          method=tf.image.ResizeMethod.BILINEAR)
+          logits, tf.shape(labels)[1:3], method=tf.image.ResizeMethod.BILINEAR)
     else:
       labels = tf.image.resize(
           labels, (height, width),
@@ -54,11 +53,9 @@ class SegmentationLoss:
     labels = tf.squeeze(tf.cast(labels, tf.int32), axis=3)
     valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=3)
 
-    onehot_labels = tf.one_hot(labels, num_classes)
-    onehot_labels = onehot_labels * (
-        1 - self._label_smoothing) + self._label_smoothing / num_classes
     cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
-        labels=onehot_labels, logits=logits)
+        labels=self.get_labels_with_prob(labels, logits, **kwargs),
+        logits=logits)
 
     if not self._class_weights:
       class_weights = [1] * num_classes
@@ -90,6 +87,26 @@ class SegmentationLoss:
     return loss
 
+  def get_labels_with_prob(self, labels, logits, **unused_kwargs):
+    """Get a tensor representing the probability of each class for each pixel.
+
+    This method can be overridden in subclasses for customizing loss function.
+
+    Args:
+      labels: A float tensor in shape (batch_size, height, width), which is the
+        label map of the ground truth.
+      logits: A float tensor in shape (batch_size, height, width, num_classes)
+        which is the output of the network.
+      **unused_kwargs: Unused keyword arguments.
+
+    Returns:
+      A float tensor in shape (batch_size, height, width, num_classes).
+    """
+    num_classes = logits.get_shape().as_list()[-1]
+    onehot_labels = tf.one_hot(labels, num_classes)
+    return onehot_labels * (
+        1 - self._label_smoothing) + self._label_smoothing / num_classes
+
 
 def get_actual_mask_scores(logits, labels, ignore_label):
   """Gets actual mask scores."""
@@ -97,8 +114,7 @@ def get_actual_mask_scores(logits, labels, ignore_label):
   batch_size = tf.shape(logits)[0]
   logits = tf.stop_gradient(logits)
   labels = tf.image.resize(
-      labels, (height, width),
-      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+      labels, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
   predicted_labels = tf.argmax(logits, -1, output_type=tf.int32)
   flat_predictions = tf.reshape(predicted_labels, [batch_size, -1])
   flat_labels = tf.cast(tf.reshape(labels, [batch_size, -1]), tf.int32)
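For reference, the formula that moved into get_labels_with_prob keeps each pixel's target a valid probability distribution: the one-hot vector is scaled by (1 - label_smoothing) and a uniform label_smoothing / num_classes is added to every class. A standalone sketch with assumed values (label_smoothing = 0.1, num_classes = 3), independent of the class above:

import tensorflow as tf

label_smoothing = 0.1                 # assumed value, for illustration only
num_classes = 3
labels = tf.constant([[[0, 2]]])      # (batch=1, height=1, width=2) class ids

onehot_labels = tf.one_hot(labels, num_classes)
smoothed = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes
# The true class gets 0.9 + 0.1/3 ~= 0.9333, every other class 0.1/3 ~= 0.0333,
# and the probabilities at each pixel still sum to 1.
print(smoothed.numpy())

Setting label_smoothing to 0 recovers plain one-hot targets, so the refactor only changes where the computation lives, not what it produces.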