ModelZoo / ResNet50_tensorflow · Commits

Commit 69231ce9, authored Sep 21, 2022 by A. Unique TensorFlower
Parent: 37e76715

Internal change

PiperOrigin-RevId: 476005369

Showing 7 changed files with 139 additions and 62 deletions (+139, -62):
  official/projects/panoptic/configs/panoptic_maskrcnn.py  (+1, -0)
  official/projects/panoptic/tasks/panoptic_maskrcnn.py    (+1, -0)
  official/vision/configs/semantic_segmentation.py         (+26, -15)
  official/vision/dataloaders/segmentation_input.py        (+46, -22)
  official/vision/dataloaders/utils.py                     (+17, -0)
  official/vision/losses/segmentation_losses.py            (+26, -6)
  official/vision/tasks/semantic_segmentation.py           (+22, -19)
official/projects/panoptic/configs/panoptic_maskrcnn.py

@@ -109,6 +109,7 @@ class Losses(maskrcnn.Losses):
   """Panoptic Mask R-CNN loss config."""
   semantic_segmentation_label_smoothing: float = 0.0
   semantic_segmentation_ignore_label: int = 255
+  semantic_segmentation_gt_is_matting_map: bool = False
   semantic_segmentation_class_weights: List[float] = dataclasses.field(
       default_factory=list)
   semantic_segmentation_use_groundtruth_dimension: bool = True
official/projects/panoptic/tasks/panoptic_maskrcnn.py

@@ -181,6 +181,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
         label_smoothing=params.semantic_segmentation_label_smoothing,
         class_weights=params.semantic_segmentation_class_weights,
         ignore_label=params.semantic_segmentation_ignore_label,
+        gt_is_matting_map=params.semantic_segmentation_gt_is_matting_map,
         use_groundtruth_dimension=use_groundtruth_dimension,
         top_k_percent_pixels=params.semantic_segmentation_top_k_percent_pixels)
official/vision/configs/semantic_segmentation.py

@@ -104,6 +104,7 @@ class Losses(hyperparams.Config):
   loss_weight: float = 1.0
   label_smoothing: float = 0.0
   ignore_label: int = 255
+  gt_is_matting_map: bool = False
   class_weights: List[float] = dataclasses.field(default_factory=list)
   l2_weight_decay: float = 0.0
   use_groundtruth_dimension: bool = True

@@ -132,8 +133,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
   evaluation: Evaluation = Evaluation()
   train_input_partition_dims: List[int] = dataclasses.field(
       default_factory=list)
   eval_input_partition_dims: List[int] = dataclasses.field(
       default_factory=list)
   init_checkpoint: Optional[str] = None
   init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone, and/or decoder

@@ -151,6 +151,7 @@ def semantic_segmentation() -> cfg.ExperimentConfig:
           'task.validation_data.is_training != None'
       ])

 # PASCAL VOC 2012 Dataset
 PASCAL_TRAIN_EXAMPLES = 10582
 PASCAL_VAL_EXAMPLES = 1449

@@ -174,11 +175,15 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
           num_classes=21,
           input_size=[None, None, 3],
           backbone=backbones.Backbone(
               type='dilated_resnet',
               dilated_resnet=backbones.DilatedResNet(
                   model_id=101,
                   output_stride=output_stride,
                   multigrid=multigrid,
                   stem_type=stem_type)),
           decoder=decoders.Decoder(
               type='aspp',
               aspp=decoders.ASPP(
                   level=level,
                   dilation_rates=aspp_dilation_rates)),
           head=SegmentationHead(level=level, num_convs=0),
           norm_activation=common.NormActivation(

@@ -262,9 +267,12 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
           num_classes=21,
           input_size=[None, None, 3],
           backbone=backbones.Backbone(
               type='dilated_resnet',
               dilated_resnet=backbones.DilatedResNet(
                   model_id=101,
                   output_stride=output_stride,
                   stem_type=stem_type,
                   multigrid=multigrid)),
           decoder=decoders.Decoder(
               type='aspp',
               aspp=decoders.ASPP(

@@ -356,8 +364,7 @@ def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
           decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
           head=SegmentationHead(level=3, num_convs=3),
           norm_activation=common.NormActivation(
               activation='swish', use_sync_bn=True)),
       losses=Losses(l2_weight_decay=1e-4),
       train_data=DataConfig(
           input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),

@@ -530,13 +537,17 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
           num_classes=19,
           input_size=[None, None, 3],
           backbone=backbones.Backbone(
               type='dilated_resnet',
               dilated_resnet=backbones.DilatedResNet(
                   model_id=101,
                   output_stride=output_stride,
                   stem_type=stem_type,
                   multigrid=multigrid)),
           decoder=decoders.Decoder(
               type='aspp',
               aspp=decoders.ASPP(
                   level=level,
                   dilation_rates=aspp_dilation_rates,
                   pool_kernel_size=[512, 1024])),
           head=SegmentationHead(
               level=level,
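A minimal sketch (not from this commit) of how the new flag would be flipped on, assuming the usual Model Garden pattern of mutating an ExperimentConfig returned by one of the factory functions above:

    from official.vision.configs import semantic_segmentation as seg_cfg

    # Start from the registered PASCAL DeepLabV3 experiment defined above.
    config = seg_cfg.seg_deeplabv3_pascal()
    # Treat segmentation groundtruth as a matting map (values 0..255); the
    # parser and loss then normalize/binarize it (see the dataloader diffs
    # below).
    config.task.losses.gt_is_matting_map = True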
official/vision/dataloaders/segmentation_input.py

@@ -17,6 +17,7 @@
 import tensorflow as tf

 from official.vision.dataloaders import decoder
 from official.vision.dataloaders import parser
+from official.vision.dataloaders import utils
 from official.vision.ops import preprocess_ops
@@ -25,26 +26,29 @@ class Decoder(decoder.Decoder):

   def __init__(self):
     self._keys_to_features = {
         'image/encoded':
             tf.io.FixedLenFeature((), tf.string, default_value=''),
         'image/height':
             tf.io.FixedLenFeature((), tf.int64, default_value=0),
         'image/width':
             tf.io.FixedLenFeature((), tf.int64, default_value=0),
         'image/segmentation/class/encoded':
             tf.io.FixedLenFeature((), tf.string, default_value='')
     }

   def decode(self, serialized_example):
     return tf.io.parse_single_example(serialized_example,
                                       self._keys_to_features)


 class Parser(parser.Parser):
-  """Parser to parse an image and its annotations into a dictionary of tensors.
-  """
+  """Parser to parse an image and its annotations into a dictionary of tensors."""

   def __init__(self,
                output_size,
                crop_size=None,
                resize_eval_groundtruth=True,
+               gt_is_matting_map=False,
                groundtruth_padded_size=None,
                ignore_label=255,
                aug_rand_hflip=False,
@@ -63,13 +67,16 @@ class Parser(parser.Parser):
         original image sizes.
       resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
         resized to output_size.
+      gt_is_matting_map: `bool`, if True, the expected mask is in the range
+        between 0 and 255. The parser will normalize the value of the mask into
+        the range between 0 and 1.
       groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
         resize_eval_groundtruth is set to False, the groundtruth masks are
         padded to this size.
       ignore_label: `int`, the pixel with ignore label will not be used for
         training and evaluation.
       aug_rand_hflip: `bool`, if True, augment training with random horizontal
         flip.
       preserve_aspect_ratio: `bool`, if True, the aspect ratio is preserved,
         otherwise, the image is resized to output_size.
       aug_scale_min: `float`, the minimum scale applied to `output_size` for
@@ -84,6 +91,7 @@ class Parser(parser.Parser):
     if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
       raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
                        'specified when resize_eval_groundtruth is False.')
+    self._gt_is_matting_map = gt_is_matting_map
     self._groundtruth_padded_size = groundtruth_padded_size
     self._ignore_label = ignore_label
     self._preserve_aspect_ratio = preserve_aspect_ratio
@@ -99,8 +107,8 @@ class Parser(parser.Parser):
   def _prepare_image_and_label(self, data):
     """Prepare normalized image and label."""
     image = tf.io.decode_image(data['image/encoded'], channels=3)
     label = tf.io.decode_image(data['image/segmentation/class/encoded'],
                                channels=1)
     height = data['image/height']
     width = data['image/width']
     image = tf.reshape(image, (height, width, 3))
@@ -122,6 +130,16 @@ class Parser(parser.Parser):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)

+    # Normalize the label into the range of 0 and 1 for matting groundtruth.
+    # Note that the input groundtruth labels must be 0 to 255, and do not
+    # contain ignore_label. For gt_is_matting_map case, ignore_label is only
+    # used for padding the labels.
+    if self._gt_is_matting_map:
+      scale = tf.constant(255.0, dtype=tf.float32)
+      scale = tf.expand_dims(scale, axis=0)
+      scale = tf.expand_dims(scale, axis=0)
+      label = tf.cast(label, tf.float32) / scale
+
     if self._crop_size:
       label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
@@ -132,8 +150,7 @@ class Parser(parser.Parser):
         label = tf.image.resize(label, self._output_size, method='nearest')
       image_mask = tf.concat([image, label], axis=2)
       image_mask_crop = tf.image.random_crop(image_mask,
                                              self._crop_size + [4])
       image = image_mask_crop[:, :, :-1]
       label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._crop_size)
@@ -159,13 +176,14 @@ class Parser(parser.Parser):
     # The label is first offset by +1 and then padded with 0.
     label += 1
     label = tf.expand_dims(label, axis=3)
     label = preprocess_ops.resize_and_crop_masks(label, image_scale,
                                                  train_image_size, offset)
     label -= 1
     label = tf.where(tf.equal(label, -1),
                      self._ignore_label * tf.ones_like(label), label)
     label = tf.squeeze(label, axis=0)
     valid_mask = tf.not_equal(label, self._ignore_label)

     labels = {
         'masks': label,
         'valid_masks': valid_mask,
@@ -180,6 +198,12 @@ class Parser(parser.Parser):
   def _parse_eval_data(self, data):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)

+    # Binarize mask if groundtruth is a matting map
+    if self._gt_is_matting_map:
+      label = tf.divide(tf.cast(label, dtype=tf.float32), 255.0)
+      label = utils.binarize_matting_map(label)
+
     # The label is first offset by +1 and then padded with 0.
     label += 1
     label = tf.expand_dims(label, axis=3)
@@ -196,13 +220,13 @@ class Parser(parser.Parser):
       label = preprocess_ops.resize_and_crop_masks(label, image_scale,
                                                    self._output_size, offset)
     else:
       label = tf.image.pad_to_bounding_box(label, 0, 0,
                                            self._groundtruth_padded_size[0],
                                            self._groundtruth_padded_size[1])
     label -= 1
     label = tf.where(tf.equal(label, -1),
                      self._ignore_label * tf.ones_like(label), label)
     label = tf.squeeze(label, axis=0)
     valid_mask = tf.not_equal(label, self._ignore_label)
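A standalone sketch (not from this commit) of the pre-existing +1/-1 offset trick both parse paths above rely on: offsetting by +1 lets the zero padding introduced by resize_and_crop_masks / pad_to_bounding_box be told apart from class 0, and after the -1 the padded pixels become -1 and are rewritten to ignore_label. Shapes are collapsed to 1-D for illustration.

    import tensorflow as tf

    ignore_label = 255
    label = tf.constant([0, 1, 2])      # real class ids
    label += 1                          # [1, 2, 3]
    label = tf.pad(label, [[0, 2]])     # zero padding -> [1, 2, 3, 0, 0]
    label -= 1                          # [0, 1, 2, -1, -1]
    label = tf.where(tf.equal(label, -1),
                     ignore_label * tf.ones_like(label), label)
    # -> [0, 1, 2, 255, 255]; the padded pixels are excluded via valid_mask.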
official/vision/dataloaders/utils.py

@@ -67,3 +67,20 @@ def pad_groundtruths_to_fixed_size(groundtruths: Dict[str, tf.Tensor],
         groundtruths['attributes'][k] = preprocess_ops.clip_or_pad_to_fixed_size(
             v, size, -1)
   return groundtruths
+
+
+def binarize_matting_map(matting_map: tf.Tensor,
+                         threshold: float = 0.5) -> tf.Tensor:
+  """Binarizes a matting map.
+
+  If the matting_map value is above a threshold, set it as 1, otherwise 0. The
+  binarization is done for every element in the matting_map.
+
+  Args:
+    matting_map: The groundtruth in the matting map format.
+    threshold: The threshold used to binarize the matting map.
+
+  Returns:
+    The binarized labels (0 for BG, 1 for FG) as tf.float32.
+  """
+  return tf.cast(tf.greater(matting_map, threshold), tf.float32)
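A quick numeric check (not from this commit) of the new helper together with the eval-path scaling shown in segmentation_input.py; the values are made up for illustration.

    import tensorflow as tf
    from official.vision.dataloaders import utils

    # A 0..255 matting map, normalized the way _parse_eval_data does.
    matting = tf.constant([0., 64., 200., 255.]) / 255.0  # ~[0., 0.25, 0.78, 1.]
    print(utils.binarize_matting_map(matting))            # [0., 0., 1., 1.]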
official/vision/losses/segmentation_losses.py

@@ -17,6 +17,7 @@
 import tensorflow as tf

 from official.modeling import tf_utils
+from official.vision.dataloaders import utils

 EPSILON = 1e-5
@@ -28,6 +29,7 @@ class SegmentationLoss:
                label_smoothing,
                class_weights,
                ignore_label,
+               gt_is_matting_map,
                use_groundtruth_dimension,
                top_k_percent_pixels=1.0):
     """Initializes `SegmentationLoss`.
@@ -37,6 +39,8 @@ class SegmentationLoss:
         spreading the amount of probability to all other label classes.
       class_weights: A float list containing the weight of each class.
       ignore_label: An integer specifying the ignore label.
+      gt_is_matting_map: Whether or not the groundtruth mask is a matting map.
+        Note that the matting map is only supported for 2-class segmentation.
       use_groundtruth_dimension: A boolean, whether to resize the output to
         match the dimension of the ground truth.
       top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
@@ -46,6 +50,7 @@ class SegmentationLoss:
     self._label_smoothing = label_smoothing
     self._class_weights = class_weights
     self._ignore_label = ignore_label
+    self._gt_is_matting_map = gt_is_matting_map
     self._use_groundtruth_dimension = use_groundtruth_dimension
     self._top_k_percent_pixels = top_k_percent_pixels
@@ -73,8 +78,12 @@ class SegmentationLoss:
           labels, (height, width),
           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
-    labels = tf.cast(labels, tf.int32)
+
+    # Do not need to cast into int32 if it is a matting map
+    if not self._gt_is_matting_map:
+      labels = tf.cast(labels, tf.int32)
+
     valid_mask = tf.not_equal(labels, self._ignore_label)
     cross_entropy_loss = self.compute_pixelwise_loss(
         labels, logits, valid_mask, **kwargs)
@@ -119,6 +128,12 @@ class SegmentationLoss:
           'Length of class_weights should be {}'.format(num_classes))
     valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=-1)

+    # If groundtruth is matting map, binarize the value to create the weight
+    # mask
+    if self._gt_is_matting_map:
+      labels = tf.cast(utils.binarize_matting_map(labels), tf.int32)
+
     weight_mask = tf.einsum(
         '...y,y->...',
         tf.one_hot(tf.squeeze(labels, axis=-1), num_classes, dtype=tf.float32),
@@ -131,8 +146,9 @@ class SegmentationLoss:
     This method can be overridden in subclasses for customizing loss function.

     Args:
-      labels: An int32 tensor in shape (batch_size, height, width, 1), which is
-        the label map of the ground truth.
+      labels: If the groundtruth mask is not a matting map, an int32 tensor
+        which is the label map of the groundtruth; otherwise a float32 tensor.
+        The shape is always (batch_size, height, width, 1).
       logits: A float tensor in shape (batch_size, height, width, num_classes)
         which is the output of the network.
       **unused_kwargs: Unused keyword arguments.
@@ -140,10 +156,14 @@ class SegmentationLoss:

     Returns:
       A float tensor in shape (batch_size, height, width, num_classes).
     """
-    labels = tf.squeeze(labels, axis=-1)
     num_classes = logits.get_shape().as_list()[-1]
-    onehot_labels = tf.one_hot(labels, num_classes)
-    return onehot_labels * (
+    if self._gt_is_matting_map:
+      train_labels = tf.concat([1 - labels, labels], axis=-1)
+    else:
+      labels = tf.squeeze(labels, axis=-1)
+      train_labels = tf.one_hot(labels, num_classes)
+    return train_labels * (
         1 - self._label_smoothing) + self._label_smoothing / num_classes

   def aggregate_loss(self, pixelwise_loss, valid_mask):
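A toy evaluation (not from this commit) of the new matting-map branch of get_train_labels: a float mask in [0, 1] becomes a soft two-class distribution via tf.concat([1 - labels, labels], axis=-1), and the existing label smoothing is then applied unchanged.

    import tensorflow as tf

    labels = tf.constant([[0.0], [0.25], [1.0]])             # matting values
    train_labels = tf.concat([1 - labels, labels], axis=-1)
    # -> [[1.0, 0.0], [0.75, 0.25], [0.0, 1.0]]
    label_smoothing, num_classes = 0.1, 2
    smoothed = (train_labels * (1 - label_smoothing)
                + label_smoothing / num_classes)
    # e.g. [1.0, 0.0] -> [0.95, 0.05]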
official/vision/tasks/semantic_segmentation.py

@@ -35,15 +35,16 @@ class SemanticSegmentationTask(base_task.Task):

   def build_model(self):
     """Builds segmentation model."""
     input_specs = tf.keras.layers.InputSpec(
         shape=[None] + self.task_config.model.input_size)

     l2_weight_decay = self.task_config.losses.l2_weight_decay
     # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
     # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
     # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
     l2_regularizer = (
         tf.keras.regularizers.l2(l2_weight_decay / 2.0)
         if l2_weight_decay else None)

     model = factory.build_segmentation_model(
         input_specs=input_specs,
@@ -85,6 +86,7 @@ class SemanticSegmentationTask(base_task.Task):
     """Builds classification input."""

     ignore_label = self.task_config.losses.ignore_label
+    gt_is_matting_map = self.task_config.losses.gt_is_matting_map

     if params.tfds_name:
       decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
@@ -96,6 +98,7 @@ class SemanticSegmentationTask(base_task.Task):
         crop_size=params.crop_size,
         ignore_label=ignore_label,
         resize_eval_groundtruth=params.resize_eval_groundtruth,
+        gt_is_matting_map=gt_is_matting_map,
         groundtruth_padded_size=params.groundtruth_padded_size,
         aug_scale_min=params.aug_scale_min,
         aug_scale_max=params.aug_scale_max,
@@ -132,6 +135,7 @@ class SemanticSegmentationTask(base_task.Task):
         loss_params.label_smoothing,
         loss_params.class_weights,
         loss_params.ignore_label,
+        loss_params.gt_is_matting_map,
         use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
         top_k_percent_pixels=loss_params.top_k_percent_pixels)
@@ -140,10 +144,9 @@ class SemanticSegmentationTask(base_task.Task):
     if 'mask_scores' in model_outputs:
       mask_scoring_loss_fn = segmentation_losses.MaskScoringLoss(
           loss_params.ignore_label)
       total_loss += mask_scoring_loss_fn(model_outputs['mask_scores'],
                                          model_outputs['logits'],
                                          labels['masks'])

     if aux_losses:
       total_loss += tf.add_n(aux_losses)
@@ -178,11 +181,12 @@ class SemanticSegmentationTask(base_task.Task):
     """Gets streaming metrics for training/validation."""
     metrics = []
     if training and self.task_config.evaluation.report_train_mean_iou:
       metrics.append(
           segmentation_metrics.MeanIoU(
               name='mean_iou',
               num_classes=self.task_config.model.num_classes,
               rescale_predictions=False,
               dtype=tf.float32))
       if self.task_config.model.get('mask_scoring_head'):
         metrics.append(
             tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
@@ -202,8 +206,8 @@ class SemanticSegmentationTask(base_task.Task):
             tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
       # Update state on CPU if TPUStrategy due to dynamic resizing.
       self._process_iou_metric_on_cpu = isinstance(
           tf.distribute.get_strategy(), tf.distribute.TPUStrategy)

     return metrics
@@ -238,8 +242,7 @@ class SemanticSegmentationTask(base_task.Task):
       outputs = {'logits': outputs}
     # Casting output layer as float32 is necessary when mixed_precision is
     # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

     # Computes per-replica loss.
     loss = self.build_losses(
@@ -296,8 +299,8 @@ class SemanticSegmentationTask(base_task.Task):
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

     if self.task_config.validation_data.resize_eval_groundtruth:
       loss = self.build_losses(
           model_outputs=outputs, labels=labels, aux_losses=model.losses)
     else:
       loss = 0
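A sketch (not from this commit) of constructing the loss with the new argument, mirroring the build_losses call above; the literal values stand in for the loss_params fields, and keyword usage assumes the parameters shown in the __init__ diff are the full signature.

    from official.vision.losses import segmentation_losses

    loss_fn = segmentation_losses.SegmentationLoss(
        label_smoothing=0.0,
        class_weights=[],
        ignore_label=255,
        gt_is_matting_map=True,            # new argument added by this change
        use_groundtruth_dimension=True,
        top_k_percent_pixels=1.0)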