Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
68d9973c
Commit
68d9973c
authored
Aug 24, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 469755043
parent
c04133d9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
104 additions
and
31 deletions
+104
-31
official/vision/losses/segmentation_losses.py
official/vision/losses/segmentation_losses.py
+104
-31
No files found.
official/vision/losses/segmentation_losses.py
View file @
68d9973c
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
"""Losses used for segmentation models."""
"""Losses used for segmentation models."""
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
...
@@ -25,16 +24,45 @@ EPSILON = 1e-5
...
@@ -25,16 +24,45 @@ EPSILON = 1e-5
class
SegmentationLoss
:
class
SegmentationLoss
:
"""Semantic segmentation loss."""
"""Semantic segmentation loss."""
def __init__(self, label_smoothing, class_weights, ignore_label,
             use_groundtruth_dimension, top_k_percent_pixels=1.0):
  """Initializes `SegmentationLoss`.

  Args:
    label_smoothing: A float, if > 0., smooth out one-hot probability by
      spreading the amount of probability to all other label classes.
    class_weights: A float list containing the weight of each class.
    ignore_label: An integer specifying the ignore label.
    use_groundtruth_dimension: A boolean, whether to resize the output to
      match the dimension of the ground truth.
    top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
      value < 1., only compute the loss for the top k percent pixels. This
      is useful for hard pixel mining.
  """
  # Plain attribute copies; all loss behavior driven by these is applied in
  # __call__ and the aggregate_* helpers.
  self._label_smoothing = label_smoothing
  self._class_weights = class_weights
  self._ignore_label = ignore_label
  self._use_groundtruth_dimension = use_groundtruth_dimension
  self._top_k_percent_pixels = top_k_percent_pixels
def
__call__
(
self
,
logits
,
labels
,
**
kwargs
):
def
__call__
(
self
,
logits
,
labels
,
**
kwargs
):
_
,
height
,
width
,
num_classes
=
logits
.
get_shape
().
as_list
()
"""Computes `SegmentationLoss`.
Args:
logits: A float tensor in shape (batch_size, height, width, num_classes)
which is the output of the network.
labels: A tensor in shape (batch_size, height, width, 1), which is the
label mask of the ground truth.
**kwargs: additional keyword arguments.
Returns:
A 0-D float which stores the overall loss of the batch.
"""
_
,
height
,
width
,
_
=
logits
.
get_shape
().
as_list
()
if
self
.
_use_groundtruth_dimension
:
if
self
.
_use_groundtruth_dimension
:
# TODO(arashwan): Test using align corners to match deeplab alignment.
# TODO(arashwan): Test using align corners to match deeplab alignment.
...
@@ -45,14 +73,38 @@ class SegmentationLoss:
...
@@ -45,14 +73,38 @@ class SegmentationLoss:
labels
,
(
height
,
width
),
labels
,
(
height
,
width
),
method
=
tf
.
image
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
method
=
tf
.
image
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
labels
=
tf
.
cast
(
labels
,
tf
.
int32
)
valid_mask
=
tf
.
not_equal
(
labels
,
self
.
_ignore_label
)
valid_mask
=
tf
.
not_equal
(
labels
,
self
.
_ignore_label
)
normalizer
=
tf
.
reduce_sum
(
tf
.
cast
(
valid_mask
,
tf
.
float32
))
+
EPSILON
cross_entropy_loss
=
self
.
compute_pixelwise_loss
(
labels
,
logits
,
valid_mask
,
**
kwargs
)
if
self
.
_top_k_percent_pixels
<
1.0
:
return
self
.
aggregate_loss_top_k
(
cross_entropy_loss
)
else
:
return
self
.
aggregate_loss
(
cross_entropy_loss
,
valid_mask
)
def
compute_pixelwise_loss
(
self
,
labels
,
logits
,
valid_mask
,
**
kwargs
):
"""Computes the loss for each pixel.
Args:
labels: An int32 tensor in shape (batch_size, height, width, 1), which is
the label mask of the ground truth.
logits: A float tensor in shape (batch_size, height, width, num_classes)
which is the output of the network.
valid_mask: A bool tensor in shape (batch_size, height, width, 1) which
masks out ignored pixels.
**kwargs: additional keyword arguments.
Returns:
A float tensor in shape (batch_size, height, width) which stores the loss
value for each pixel.
"""
num_classes
=
logits
.
get_shape
().
as_list
()[
-
1
]
# Assign pixel with ignore label to class 0 (background). The loss on the
# Assign pixel with ignore label to class 0 (background). The loss on the
# pixel will later be masked out.
# pixel will later be masked out.
labels
=
tf
.
where
(
valid_mask
,
labels
,
tf
.
zeros_like
(
labels
))
labels
=
tf
.
where
(
valid_mask
,
labels
,
tf
.
zeros_like
(
labels
))
labels
=
tf
.
squeeze
(
tf
.
cast
(
labels
,
tf
.
int32
),
axis
=
3
)
valid_mask
=
tf
.
squeeze
(
tf
.
cast
(
valid_mask
,
tf
.
float32
),
axis
=
3
)
cross_entropy_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
cross_entropy_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
=
self
.
get_labels_with_prob
(
labels
,
logits
,
**
kwargs
),
labels
=
self
.
get_labels_with_prob
(
labels
,
logits
,
**
kwargs
),
logits
=
logits
)
logits
=
logits
)
...
@@ -66,26 +118,12 @@ class SegmentationLoss:
...
@@ -66,26 +118,12 @@ class SegmentationLoss:
raise
ValueError
(
raise
ValueError
(
'Length of class_weights should be {}'
.
format
(
num_classes
))
'Length of class_weights should be {}'
.
format
(
num_classes
))
weight_mask
=
tf
.
einsum
(
'...y,y->...'
,
valid_mask
=
tf
.
squeeze
(
tf
.
cast
(
valid_mask
,
tf
.
float32
),
axis
=-
1
)
tf
.
one_hot
(
labels
,
num_classes
,
dtype
=
tf
.
float32
),
weight_mask
=
tf
.
einsum
(
tf
.
constant
(
class_weights
,
tf
.
float32
))
'...y,y->...'
,
valid_mask
*=
weight_mask
tf
.
one_hot
(
tf
.
squeeze
(
labels
,
axis
=-
1
),
num_classes
,
dtype
=
tf
.
float32
),
cross_entropy_loss
*=
tf
.
cast
(
valid_mask
,
tf
.
float32
)
tf
.
constant
(
class_weights
,
tf
.
float32
))
return
cross_entropy_loss
*
valid_mask
*
weight_mask
if
self
.
_top_k_percent_pixels
>=
1.0
:
loss
=
tf
.
reduce_sum
(
cross_entropy_loss
)
/
normalizer
else
:
cross_entropy_loss
=
tf
.
reshape
(
cross_entropy_loss
,
shape
=
[
-
1
])
top_k_pixels
=
tf
.
cast
(
self
.
_top_k_percent_pixels
*
tf
.
cast
(
tf
.
size
(
cross_entropy_loss
),
tf
.
float32
),
tf
.
int32
)
top_k_losses
,
_
=
tf
.
math
.
top_k
(
cross_entropy_loss
,
k
=
top_k_pixels
,
sorted
=
True
)
normalizer
=
tf
.
reduce_sum
(
tf
.
cast
(
tf
.
not_equal
(
top_k_losses
,
0.0
),
tf
.
float32
))
+
EPSILON
loss
=
tf
.
reduce_sum
(
top_k_losses
)
/
normalizer
return
loss
def get_labels_with_prob(self, labels, logits, **unused_kwargs):
  """Gets a tensor of per-class probabilities for each pixel.

  Converts the ground-truth label map to one-hot form and applies uniform
  label smoothing. This method can be overridden in subclasses for
  customizing the loss function.

  Args:
    labels: An int32 tensor in shape (batch_size, height, width, 1), which
      is the label map of the ground truth.
    logits: A float tensor in shape (batch_size, height, width, num_classes)
      which is the output of the network.
    **unused_kwargs: Unused keyword arguments.

  Returns:
    A float tensor in shape (batch_size, height, width, num_classes).
  """
  # Drop the trailing channel dimension so tf.one_hot produces rank-4 output.
  labels = tf.squeeze(labels, axis=-1)
  num_classes = logits.get_shape().as_list()[-1]
  onehot_labels = tf.one_hot(labels, num_classes)
  # Uniform smoothing: keep (1 - s) on the true class, spread s across all
  # classes equally.
  smooth = self._label_smoothing
  return onehot_labels * (1 - smooth) + smooth / num_classes
def aggregate_loss(self, pixelwise_loss, valid_mask):
  """Aggregates the pixelwise loss into a scalar batch loss.

  Args:
    pixelwise_loss: A float tensor in shape (batch_size, height, width)
      which stores the loss of each pixel.
    valid_mask: A bool tensor in shape (batch_size, height, width, 1) which
      masks out ignored pixels.

  Returns:
    A 0-D float which stores the overall loss of the batch.
  """
  # Mean over valid (non-ignored) pixels only; EPSILON guards against
  # division by zero when every pixel is ignored.
  num_valid = tf.reduce_sum(tf.cast(valid_mask, tf.float32))
  return tf.reduce_sum(pixelwise_loss) / (num_valid + EPSILON)
def aggregate_loss_top_k(self, pixelwise_loss):
  """Aggregates the top-k greatest pixelwise losses (hard pixel mining).

  Args:
    pixelwise_loss: A float tensor in shape (batch_size, height, width)
      which stores the loss of each pixel.

  Returns:
    A 0-D float which stores the overall loss of the batch.
  """
  flat_loss = tf.reshape(pixelwise_loss, shape=[-1])
  # Number of pixels to keep: a fixed fraction of all pixels in the batch.
  k = tf.cast(
      self._top_k_percent_pixels * tf.cast(tf.size(flat_loss), tf.float32),
      tf.int32)
  top_k_losses, _ = tf.math.top_k(flat_loss, k=k, sorted=True)
  # Normalize by the count of selected pixels with non-zero loss; EPSILON
  # guards against division by zero.
  normalizer = tf.reduce_sum(
      tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
  return tf.reduce_sum(top_k_losses) / normalizer
def
get_actual_mask_scores
(
logits
,
labels
,
ignore_label
):
def
get_actual_mask_scores
(
logits
,
labels
,
ignore_label
):
"""Gets actual mask scores."""
"""Gets actual mask scores."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment