Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8d71f896
Commit
8d71f896
authored
Sep 23, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 476492677
parent
d38e8da1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
7 deletions
+10
-7
official/vision/losses/segmentation_losses.py
official/vision/losses/segmentation_losses.py
+7
-5
official/vision/tasks/semantic_segmentation.py
official/vision/tasks/semantic_segmentation.py
+3
-2
No files found.
official/vision/losses/segmentation_losses.py
View file @
8d71f896
...
@@ -29,9 +29,10 @@ class SegmentationLoss:
...
@@ -29,9 +29,10 @@ class SegmentationLoss:
label_smoothing
,
label_smoothing
,
class_weights
,
class_weights
,
ignore_label
,
ignore_label
,
gt_is_matting_map
,
use_groundtruth_dimension
,
use_groundtruth_dimension
,
top_k_percent_pixels
=
1.0
):
top_k_percent_pixels
=
1.0
,
gt_is_matting_map
=
False
):
"""Initializes `SegmentationLoss`.
"""Initializes `SegmentationLoss`.
Args:
Args:
...
@@ -39,20 +40,21 @@ class SegmentationLoss:
...
@@ -39,20 +40,21 @@ class SegmentationLoss:
spreading the amount of probability to all other label classes.
spreading the amount of probability to all other label classes.
class_weights: A float list containing the weight of each class.
class_weights: A float list containing the weight of each class.
ignore_label: An integer specifying the ignore label.
ignore_label: An integer specifying the ignore label.
gt_is_matting_map: If or not the groundtruth mask is a matting map. Note
that the matting map is only supported for 2 class segmentation.
use_groundtruth_dimension: A boolean, whether to resize the output to
use_groundtruth_dimension: A boolean, whether to resize the output to
match the dimension of the ground truth.
match the dimension of the ground truth.
top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
value < 1., only compute the loss for the top k percent pixels. This is
value < 1., only compute the loss for the top k percent pixels. This is
useful for hard pixel mining.
useful for hard pixel mining.
gt_is_matting_map: If or not the groundtruth mask is a matting map. Note
that the matting map is only supported for 2 class segmentation.
"""
"""
self
.
_label_smoothing
=
label_smoothing
self
.
_label_smoothing
=
label_smoothing
self
.
_class_weights
=
class_weights
self
.
_class_weights
=
class_weights
self
.
_ignore_label
=
ignore_label
self
.
_ignore_label
=
ignore_label
self
.
_gt_is_matting_map
=
gt_is_matting_map
self
.
_use_groundtruth_dimension
=
use_groundtruth_dimension
self
.
_use_groundtruth_dimension
=
use_groundtruth_dimension
self
.
_top_k_percent_pixels
=
top_k_percent_pixels
self
.
_top_k_percent_pixels
=
top_k_percent_pixels
self
.
_gt_is_matting_map
=
gt_is_matting_map
def
__call__
(
self
,
logits
,
labels
,
**
kwargs
):
def
__call__
(
self
,
logits
,
labels
,
**
kwargs
):
"""Computes `SegmentationLoss`.
"""Computes `SegmentationLoss`.
...
...
official/vision/tasks/semantic_segmentation.py
View file @
8d71f896
...
@@ -135,9 +135,10 @@ class SemanticSegmentationTask(base_task.Task):
...
@@ -135,9 +135,10 @@ class SemanticSegmentationTask(base_task.Task):
loss_params
.
label_smoothing
,
loss_params
.
label_smoothing
,
loss_params
.
class_weights
,
loss_params
.
class_weights
,
loss_params
.
ignore_label
,
loss_params
.
ignore_label
,
loss_params
.
gt_is_matting_map
,
use_groundtruth_dimension
=
loss_params
.
use_groundtruth_dimension
,
use_groundtruth_dimension
=
loss_params
.
use_groundtruth_dimension
,
top_k_percent_pixels
=
loss_params
.
top_k_percent_pixels
)
top_k_percent_pixels
=
loss_params
.
top_k_percent_pixels
,
gt_is_matting_map
=
loss_params
.
gt_is_matting_map
)
total_loss
=
segmentation_loss_fn
(
model_outputs
[
'logits'
],
labels
[
'masks'
])
total_loss
=
segmentation_loss_fn
(
model_outputs
[
'logits'
],
labels
[
'masks'
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment