Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bb6b143c
Commit
bb6b143c
authored
Apr 20, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 443222400
parent
2ce12046
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
9 deletions
+25
-9
official/vision/losses/segmentation_losses.py
official/vision/losses/segmentation_losses.py
+25
-9
No files found.
official/vision/losses/segmentation_losses.py
View file @
bb6b143c
...
@@ -33,14 +33,13 @@ class SegmentationLoss:
...
@@ -33,14 +33,13 @@ class SegmentationLoss:
self
.
_use_groundtruth_dimension
=
use_groundtruth_dimension
self
.
_use_groundtruth_dimension
=
use_groundtruth_dimension
self
.
_label_smoothing
=
label_smoothing
self
.
_label_smoothing
=
label_smoothing
def
__call__
(
self
,
logits
,
labels
):
def
__call__
(
self
,
logits
,
labels
,
**
kwargs
):
_
,
height
,
width
,
num_classes
=
logits
.
get_shape
().
as_list
()
_
,
height
,
width
,
num_classes
=
logits
.
get_shape
().
as_list
()
if
self
.
_use_groundtruth_dimension
:
if
self
.
_use_groundtruth_dimension
:
# TODO(arashwan): Test using align corners to match deeplab alignment.
# TODO(arashwan): Test using align corners to match deeplab alignment.
logits
=
tf
.
image
.
resize
(
logits
=
tf
.
image
.
resize
(
logits
,
tf
.
shape
(
labels
)[
1
:
3
],
logits
,
tf
.
shape
(
labels
)[
1
:
3
],
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
else
:
else
:
labels
=
tf
.
image
.
resize
(
labels
=
tf
.
image
.
resize
(
labels
,
(
height
,
width
),
labels
,
(
height
,
width
),
...
@@ -54,11 +53,9 @@ class SegmentationLoss:
...
@@ -54,11 +53,9 @@ class SegmentationLoss:
labels
=
tf
.
squeeze
(
tf
.
cast
(
labels
,
tf
.
int32
),
axis
=
3
)
labels
=
tf
.
squeeze
(
tf
.
cast
(
labels
,
tf
.
int32
),
axis
=
3
)
valid_mask
=
tf
.
squeeze
(
tf
.
cast
(
valid_mask
,
tf
.
float32
),
axis
=
3
)
valid_mask
=
tf
.
squeeze
(
tf
.
cast
(
valid_mask
,
tf
.
float32
),
axis
=
3
)
onehot_labels
=
tf
.
one_hot
(
labels
,
num_classes
)
onehot_labels
=
onehot_labels
*
(
1
-
self
.
_label_smoothing
)
+
self
.
_label_smoothing
/
num_classes
cross_entropy_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
cross_entropy_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
=
onehot_labels
,
logits
=
logits
)
labels
=
self
.
get_labels_with_prob
(
labels
,
logits
,
**
kwargs
),
logits
=
logits
)
if
not
self
.
_class_weights
:
if
not
self
.
_class_weights
:
class_weights
=
[
1
]
*
num_classes
class_weights
=
[
1
]
*
num_classes
...
@@ -90,6 +87,26 @@ class SegmentationLoss:
...
@@ -90,6 +87,26 @@ class SegmentationLoss:
return
loss
return
loss
def
get_labels_with_prob
(
self
,
labels
,
logits
,
**
unused_kwargs
):
"""Get a tensor representing the probability of each class for each pixel.
This method can be overridden in subclasses for customizing loss function.
Args:
labels: A float tensor in shape (batch_size, height, width), which is the
label map of the ground truth.
logits: A float tensor in shape (batch_size, height, width, num_classes)
which is the output of the network.
**unused_kwargs: Unused keyword arguments.
Returns:
A float tensor in shape (batch_size, height, width, num_classes).
"""
num_classes
=
logits
.
get_shape
().
as_list
()[
-
1
]
onehot_labels
=
tf
.
one_hot
(
labels
,
num_classes
)
return
onehot_labels
*
(
1
-
self
.
_label_smoothing
)
+
self
.
_label_smoothing
/
num_classes
def
get_actual_mask_scores
(
logits
,
labels
,
ignore_label
):
def
get_actual_mask_scores
(
logits
,
labels
,
ignore_label
):
"""Gets actual mask scores."""
"""Gets actual mask scores."""
...
@@ -97,8 +114,7 @@ def get_actual_mask_scores(logits, labels, ignore_label):
...
@@ -97,8 +114,7 @@ def get_actual_mask_scores(logits, labels, ignore_label):
batch_size
=
tf
.
shape
(
logits
)[
0
]
batch_size
=
tf
.
shape
(
logits
)[
0
]
logits
=
tf
.
stop_gradient
(
logits
)
logits
=
tf
.
stop_gradient
(
logits
)
labels
=
tf
.
image
.
resize
(
labels
=
tf
.
image
.
resize
(
labels
,
(
height
,
width
),
labels
,
(
height
,
width
),
method
=
tf
.
image
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
method
=
tf
.
image
.
ResizeMethod
.
NEAREST_NEIGHBOR
)
predicted_labels
=
tf
.
argmax
(
logits
,
-
1
,
output_type
=
tf
.
int32
)
predicted_labels
=
tf
.
argmax
(
logits
,
-
1
,
output_type
=
tf
.
int32
)
flat_predictions
=
tf
.
reshape
(
predicted_labels
,
[
batch_size
,
-
1
])
flat_predictions
=
tf
.
reshape
(
predicted_labels
,
[
batch_size
,
-
1
])
flat_labels
=
tf
.
cast
(
tf
.
reshape
(
labels
,
[
batch_size
,
-
1
]),
tf
.
int32
)
flat_labels
=
tf
.
cast
(
tf
.
reshape
(
labels
,
[
batch_size
,
-
1
]),
tf
.
int32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment