Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6e0d65cb
Unverified
Commit
6e0d65cb
authored
Mar 20, 2022
by
srihari-humbarwadi
Browse files
refactor losses; implemented weighted semantic loss
parent
abfd0698
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
48 deletions
+106
-48
official/vision/beta/projects/panoptic_maskrcnn/losses/panoptic_deeplab_losses.py
...jects/panoptic_maskrcnn/losses/panoptic_deeplab_losses.py
+93
-35
official/vision/beta/projects/panoptic_maskrcnn/tasks/panoptic_deeplab.py
...beta/projects/panoptic_maskrcnn/tasks/panoptic_deeplab.py
+13
-13
No files found.
official/vision/beta/projects/panoptic_maskrcnn/losses/panoptic_deeplab_losses.py
View file @
6e0d65cb
...
@@ -12,54 +12,112 @@
...
@@ -12,54 +12,112 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
Instance center l
osses used for panoptic deeplab model."""
"""
L
osses used for panoptic deeplab model."""
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.beta.projects.panoptic_maskrcnn.ops
import
mask_ops
EPSILON
=
1e-5
class
WeightedBootstrappedCrossEntropyLoss
:
"""Weighted semantic segmentation loss."""
def
__init__
(
self
,
label_smoothing
,
class_weights
,
ignore_label
,
top_k_percent_pixels
=
1.0
):
self
.
_top_k_percent_pixels
=
top_k_percent_pixels
self
.
_class_weights
=
class_weights
self
.
_ignore_label
=
ignore_label
self
.
_label_smoothing
=
label_smoothing
class
CenterLoss
:
def
__call__
(
self
,
logits
,
labels
,
sample_weight
=
None
):
"""Instance center loss."""
_
,
_
,
_
,
num_classes
=
logits
.
get_shape
().
as_list
()
_LOSS_FN
=
{
'mse'
:
tf
.
losses
.
mean_squared_error
,
'mae'
:
tf
.
losses
.
mean_absolute_error
}
def
__init__
(
self
,
use_groundtruth_dimension
:
bool
,
loss_type
:
str
):
logits
=
tf
.
image
.
resize
(
if
loss_type
.
lower
()
not
in
{
'mse'
,
'mae'
}:
logits
,
tf
.
shape
(
labels
)[
1
:
3
],
raise
ValueError
(
'Unsupported `loss_type` supported. Available loss '
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
'types: mse/mae'
)
self
.
_use_groundtruth_dimension
=
use_groundtruth_dimension
valid_mask
=
tf
.
not_equal
(
labels
,
self
.
_ignore_label
)
self
.
loss_type
=
loss_type
normalizer
=
tf
.
reduce_sum
(
tf
.
cast
(
valid_mask
,
tf
.
float32
))
+
EPSILON
self
.
_loss_fn
=
CenterLoss
.
_LOSS_FN
[
self
.
loss_type
]
# Assign pixel with ignore label to class 0 (background). The loss on the
# pixel will later be masked out.
labels
=
tf
.
where
(
valid_mask
,
labels
,
tf
.
zeros_like
(
labels
))
def
__call__
(
self
,
logits
,
labels
,
sample_weight
):
labels
=
tf
.
squeeze
(
tf
.
cast
(
labels
,
tf
.
int32
),
axis
=
3
)
_
,
height
,
width
,
_
=
logits
.
get_shape
().
as_list
()
valid_mask
=
tf
.
squeeze
(
tf
.
cast
(
valid_mask
,
tf
.
float32
),
axis
=
3
)
onehot_labels
=
tf
.
one_hot
(
labels
,
num_classes
)
onehot_labels
=
onehot_labels
*
(
1
-
self
.
_label_smoothing
)
+
self
.
_label_smoothing
/
num_classes
cross_entropy_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
=
onehot_labels
,
logits
=
logits
)
if
self
.
_use_groundtruth_dimension
:
if
not
self
.
_class_weights
:
logits
=
tf
.
image
.
resize
(
class_weights
=
[
1
]
*
num_classes
logits
,
tf
.
shape
(
labels
)[
1
:
3
],
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
else
:
else
:
labels
=
tf
.
image
.
resize
(
class_weights
=
self
.
_class_weights
labels
,
(
height
,
width
),
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
if
num_classes
!=
len
(
class_weights
):
raise
ValueError
(
'Length of class_weights should be {}'
.
format
(
num_classes
))
weight_mask
=
tf
.
einsum
(
'...y,y->...'
,
tf
.
one_hot
(
labels
,
num_classes
,
dtype
=
tf
.
float32
),
tf
.
constant
(
class_weights
,
tf
.
float32
))
valid_mask
*=
weight_mask
if
sample_weight
is
not
None
:
valid_mask
*=
sample_weight
cross_entropy_loss
*=
tf
.
cast
(
valid_mask
,
tf
.
float32
)
if
self
.
_top_k_percent_pixels
>=
1.0
:
loss
=
tf
.
reduce_sum
(
cross_entropy_loss
)
/
normalizer
else
:
cross_entropy_loss
=
tf
.
reshape
(
cross_entropy_loss
,
shape
=
[
-
1
])
top_k_pixels
=
tf
.
cast
(
self
.
_top_k_percent_pixels
*
tf
.
cast
(
tf
.
size
(
cross_entropy_loss
),
tf
.
float32
),
tf
.
int32
)
top_k_losses
,
_
=
tf
.
math
.
top_k
(
cross_entropy_loss
,
k
=
top_k_pixels
,
sorted
=
True
)
normalizer
=
tf
.
reduce_sum
(
tf
.
cast
(
tf
.
not_equal
(
top_k_losses
,
0.0
),
tf
.
float32
))
+
EPSILON
loss
=
tf
.
reduce_sum
(
top_k_losses
)
/
normalizer
return
loss
class
CenterHeatmapLoss
:
def
__init__
(
self
):
self
.
_loss_fn
=
tf
.
losses
.
mean_squared_error
def
__call__
(
self
,
logits
,
labels
,
sample_weight
=
None
):
_
,
height
,
width
,
_
=
labels
.
get_shape
().
as_list
()
logits
=
tf
.
image
.
resize
(
logits
,
size
=
[
height
,
width
],
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
loss
=
self
.
_loss_fn
(
y_true
=
labels
,
y_pred
=
logits
)
loss
=
self
.
_loss_fn
(
y_true
=
labels
,
y_pred
=
logits
)
return
tf_utils
.
safe_mean
(
loss
*
sample_weight
)
if
sample_weight
is
not
None
:
loss
*=
sample_weight
return
tf_utils
.
safe_mean
(
loss
)
class
CenterOffsetLoss
:
def
__init__
(
self
):
self
.
_loss_fn
=
tf
.
losses
.
mean_absolute_error
def
__call__
(
self
,
logits
,
labels
,
sample_weight
=
None
):
_
,
height
,
width
,
_
=
labels
.
get_shape
().
as_list
()
logits
=
mask_ops
.
resize_and_rescale_offsets
(
logits
,
target_size
=
[
height
,
width
])
loss
=
self
.
_loss_fn
(
y_true
=
labels
,
y_pred
=
logits
)
class
CenterHeatmapLoss
(
CenterLoss
):
if
sample_weight
is
not
None
:
def
__init__
(
self
,
use_groundtruth_dimension
):
loss
*=
sample_weight
super
(
CenterHeatmapLoss
,
self
).
__init__
(
use_groundtruth_dimension
=
use_groundtruth_dimension
,
loss_type
=
'mse'
)
class
CenterOffsetLoss
(
CenterLoss
):
return
tf_utils
.
safe_mean
(
loss
)
def
__init__
(
self
,
use_groundtruth_dimension
):
super
(
CenterOffsetLoss
,
self
).
__init__
(
use_groundtruth_dimension
=
use_groundtruth_dimension
,
loss_type
=
'mae'
)
official/vision/beta/projects/panoptic_maskrcnn/tasks/panoptic_deeplab.py
View file @
6e0d65cb
...
@@ -28,7 +28,6 @@ from official.vision.beta.projects.panoptic_maskrcnn.losses import panoptic_deep
...
@@ -28,7 +28,6 @@ from official.vision.beta.projects.panoptic_maskrcnn.losses import panoptic_deep
from
official.vision.dataloaders
import
input_reader_factory
from
official.vision.dataloaders
import
input_reader_factory
from
official.vision.evaluation
import
panoptic_quality_evaluator
from
official.vision.evaluation
import
panoptic_quality_evaluator
from
official.vision.evaluation
import
segmentation_metrics
from
official.vision.evaluation
import
segmentation_metrics
from
official.vision.losses
import
segmentation_losses
@
task_factory
.
register_task_cls
(
exp_cfg
.
PanopticDeeplabTask
)
@
task_factory
.
register_task_cls
(
exp_cfg
.
PanopticDeeplabTask
)
...
@@ -131,28 +130,29 @@ class PanopticDeeplabTask(base_task.Task):
...
@@ -131,28 +130,29 @@ class PanopticDeeplabTask(base_task.Task):
The total loss tensor.
The total loss tensor.
"""
"""
loss_config
=
self
.
_task_config
.
losses
loss_config
=
self
.
_task_config
.
losses
segmentation_loss_fn
=
segmentation_losses
.
Segmentation
Loss
(
segmentation_loss_fn
=
panoptic_deeplab_losses
.
WeightedBootstrappedCrossEntropy
Loss
(
loss_config
.
label_smoothing
,
loss_config
.
label_smoothing
,
loss_config
.
class_weights
,
loss_config
.
class_weights
,
loss_config
.
ignore_label
,
loss_config
.
ignore_label
,
use_groundtruth_dimension
=
loss_config
.
use_groundtruth_dimension
,
top_k_percent_pixels
=
loss_config
.
top_k_percent_pixels
)
top_k_percent_pixels
=
loss_config
.
top_k_percent_pixels
)
instance_center_heatmap_loss_fn
=
panoptic_deeplab_losses
.
CenterHeatmapLoss
(
instance_center_heatmap_loss_fn
=
panoptic_deeplab_losses
.
CenterHeatmapLoss
()
use_groundtruth_dimension
=
loss_config
.
use_groundtruth_dimension
)
instance_center_offset_loss_fn
=
panoptic_deeplab_losses
.
CenterOffsetLoss
()
instance_center_offset_loss_fn
=
panoptic_deeplab_losses
.
CenterOffsetLoss
(
use_groundtruth_dimension
=
loss_config
.
use_groundtruth_dimension
)
segmentation_loss
=
segmentation_loss_fn
(
model_outputs
[
'segmentation_outputs'
],
labels
[
'category_mask'
])
semantic_weights
=
tf
.
cast
(
labels
[
'semantic_weights'
],
dtype
=
model_outputs
[
'instance_centers_heatmap'
].
dtype
)
things_mask
=
tf
.
cast
(
things_mask
=
tf
.
cast
(
tf
.
squeeze
(
labels
[
'things_mask'
],
axis
=
3
),
labels
[
'things_mask'
],
dtype
=
model_outputs
[
'instance_centers_heatmap'
].
dtype
)
dtype
=
model_outputs
[
'instance_centers_heatmap'
].
dtype
)
valid_mask
=
tf
.
cast
(
valid_mask
=
tf
.
cast
(
tf
.
squeeze
(
labels
[
'valid_mask'
],
axis
=
3
),
labels
[
'valid_mask'
],
dtype
=
model_outputs
[
'instance_centers_heatmap'
].
dtype
)
dtype
=
model_outputs
[
'instance_centers_heatmap'
].
dtype
)
segmentation_loss
=
segmentation_loss_fn
(
model_outputs
[
'segmentation_outputs'
],
labels
[
'category_mask'
],
sample_weight
=
semantic_weights
)
instance_center_heatmap_loss
=
instance_center_heatmap_loss_fn
(
instance_center_heatmap_loss
=
instance_center_heatmap_loss_fn
(
model_outputs
[
'instance_centers_heatmap'
],
model_outputs
[
'instance_centers_heatmap'
],
labels
[
'instance_centers_heatmap'
],
labels
[
'instance_centers_heatmap'
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment