ModelZoo / ResNet50_tensorflow / Commits

Commit 69231ce9, authored Sep 21, 2022 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 476005369
Parent: 37e76715
Showing 7 changed files with 139 additions and 62 deletions.
official/projects/panoptic/configs/panoptic_maskrcnn.py   +1  -0
official/projects/panoptic/tasks/panoptic_maskrcnn.py     +1  -0
official/vision/configs/semantic_segmentation.py          +26 -15
official/vision/dataloaders/segmentation_input.py         +46 -22
official/vision/dataloaders/utils.py                      +17 -0
official/vision/losses/segmentation_losses.py             +26 -6
official/vision/tasks/semantic_segmentation.py            +22 -19
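Taken together, these changes thread a new gt_is_matting_map flag from the loss config through the input parser and the segmentation loss, so groundtruth masks can be consumed as [0, 255] matting maps instead of integer label maps. A minimal sketch of turning the flag on, assuming the registered seg_deeplabv3_pascal experiment and the standard exp_factory flow (neither is shown on this page):

    # Hypothetical usage sketch: the experiment name and exp_factory call are
    # standard Model Garden APIs assumed here, not part of this commit; only
    # the `gt_is_matting_map` field itself is new.
    from official.core import exp_factory

    config = exp_factory.get_exp_config('seg_deeplabv3_pascal')
    # Treat groundtruth masks as matting maps in [0, 255]: the parser rescales
    # them to [0, 1] (and binarizes for eval), and the loss builds soft
    # two-class labels from them; ignore_label is then only used for padding.
    config.task.losses.gt_is_matting_map = True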
official/projects/panoptic/configs/panoptic_maskrcnn.py

@@ -109,6 +109,7 @@ class Losses(maskrcnn.Losses):
   """Panoptic Mask R-CNN loss config."""
   semantic_segmentation_label_smoothing: float = 0.0
   semantic_segmentation_ignore_label: int = 255
+  semantic_segmentation_gt_is_matting_map: bool = False
   semantic_segmentation_class_weights: List[float] = dataclasses.field(
       default_factory=list)
   semantic_segmentation_use_groundtruth_dimension: bool = True
official/projects/panoptic/tasks/panoptic_maskrcnn.py

@@ -181,6 +181,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
         label_smoothing=params.semantic_segmentation_label_smoothing,
         class_weights=params.semantic_segmentation_class_weights,
         ignore_label=params.semantic_segmentation_ignore_label,
+        gt_is_matting_map=params.semantic_segmentation_gt_is_matting_map,
         use_groundtruth_dimension=use_groundtruth_dimension,
         top_k_percent_pixels=params.semantic_segmentation_top_k_percent_pixels)
official/vision/configs/semantic_segmentation.py

@@ -104,6 +104,7 @@ class Losses(hyperparams.Config):
   loss_weight: float = 1.0
   label_smoothing: float = 0.0
   ignore_label: int = 255
+  gt_is_matting_map: bool = False
   class_weights: List[float] = dataclasses.field(default_factory=list)
   l2_weight_decay: float = 0.0
   use_groundtruth_dimension: bool = True

@@ -132,8 +133,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
   evaluation: Evaluation = Evaluation()
   train_input_partition_dims: List[int] = dataclasses.field(
       default_factory=list)
   eval_input_partition_dims: List[int] = dataclasses.field(
       default_factory=list)
   init_checkpoint: Optional[str] = None
   init_checkpoint_modules: Union[
       str, List[str]] = 'all'  # all, backbone, and/or decoder

@@ -151,6 +151,7 @@ def semantic_segmentation() -> cfg.ExperimentConfig:
           'task.validation_data.is_training != None'
       ])

 # PASCAL VOC 2012 Dataset
 PASCAL_TRAIN_EXAMPLES = 10582
 PASCAL_VAL_EXAMPLES = 1449

@@ -174,11 +175,15 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
           num_classes=21,
           input_size=[None, None, 3],
           backbone=backbones.Backbone(
-              type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                  model_id=101, output_stride=output_stride,
-                  multigrid=multigrid, stem_type=stem_type)),
+              type='dilated_resnet',
+              dilated_resnet=backbones.DilatedResNet(
+                  model_id=101, output_stride=output_stride,
+                  multigrid=multigrid, stem_type=stem_type)),
           decoder=decoders.Decoder(
-              type='aspp', aspp=decoders.ASPP(
-                  level=level, dilation_rates=aspp_dilation_rates)),
+              type='aspp',
+              aspp=decoders.ASPP(
+                  level=level, dilation_rates=aspp_dilation_rates)),
           head=SegmentationHead(level=level, num_convs=0),
           norm_activation=common.NormActivation(

@@ -262,9 +267,12 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
           num_classes=21,
           input_size=[None, None, 3],
           backbone=backbones.Backbone(
-              type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                  model_id=101, output_stride=output_stride,
-                  stem_type=stem_type, multigrid=multigrid)),
+              type='dilated_resnet',
+              dilated_resnet=backbones.DilatedResNet(
+                  model_id=101, output_stride=output_stride,
+                  stem_type=stem_type, multigrid=multigrid)),
           decoder=decoders.Decoder(
               type='aspp',
               aspp=decoders.ASPP(

@@ -356,8 +364,7 @@ def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
           decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
           head=SegmentationHead(level=3, num_convs=3),
           norm_activation=common.NormActivation(
               activation='swish',
               use_sync_bn=True)),
       losses=Losses(l2_weight_decay=1e-4),
       train_data=DataConfig(
           input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),

@@ -530,13 +537,17 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
           num_classes=19,
           input_size=[None, None, 3],
           backbone=backbones.Backbone(
-              type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                  model_id=101, output_stride=output_stride,
-                  stem_type=stem_type, multigrid=multigrid)),
+              type='dilated_resnet',
+              dilated_resnet=backbones.DilatedResNet(
+                  model_id=101, output_stride=output_stride,
+                  stem_type=stem_type, multigrid=multigrid)),
           decoder=decoders.Decoder(
               type='aspp',
               aspp=decoders.ASPP(
                   level=level, dilation_rates=aspp_dilation_rates,
                   pool_kernel_size=[512, 1024])),
           head=SegmentationHead(
               level=level,
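The config hunks above repeatedly use dataclasses.field(default_factory=list) for list-valued fields. A short self-contained sketch of why that pattern is required (the class below is a simplified stand-in, not the real hyperparams.Config):

    # A bare mutable default (`class_weights: List[float] = []`) is rejected by
    # dataclasses at class-definition time, precisely because it would be shared
    # by every instance; default_factory builds a fresh list per instance.
    import dataclasses
    from typing import List

    @dataclasses.dataclass
    class LossesSketch:  # simplified stand-in, not the real config class
      ignore_label: int = 255
      gt_is_matting_map: bool = False
      class_weights: List[float] = dataclasses.field(default_factory=list)

    a, b = LossesSketch(), LossesSketch()
    a.class_weights.append(1.0)
    assert b.class_weights == []  # each instance owns its own list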
official/vision/dataloaders/segmentation_input.py

@@ -17,6 +17,7 @@
 import tensorflow as tf

 from official.vision.dataloaders import decoder
 from official.vision.dataloaders import parser
+from official.vision.dataloaders import utils
 from official.vision.ops import preprocess_ops

@@ -25,26 +26,29 @@ class Decoder(decoder.Decoder):
   def __init__(self):
     self._keys_to_features = {
-        'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
+        'image/encoded':
+            tf.io.FixedLenFeature((), tf.string, default_value=''),
         'image/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
         'image/width': tf.io.FixedLenFeature((), tf.int64, default_value=0),
         'image/segmentation/class/encoded':
             tf.io.FixedLenFeature((), tf.string, default_value='')
     }

   def decode(self, serialized_example):
     return tf.io.parse_single_example(
         serialized_example,
         self._keys_to_features)


 class Parser(parser.Parser):
-  """Parser to parse an image and its annotations into a dictionary of tensors.
-  """
+  """Parser to parse an image and its annotations into a dictionary of tensors."""

   def __init__(self,
                output_size,
                crop_size=None,
                resize_eval_groundtruth=True,
+               gt_is_matting_map=False,
                groundtruth_padded_size=None,
                ignore_label=255,
                aug_rand_hflip=False,

@@ -63,13 +67,16 @@ class Parser(parser.Parser):
         original image sizes.
       resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
         resized to output_size.
+      gt_is_matting_map: `bool`, if True, the expected mask is in the range
+        between 0 and 255. The parser will normalize the value of the mask into
+        the range between 0 and 1.
       groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
         resize_eval_groundtruth is set to False, the groundtruth masks are
         padded to this size.
       ignore_label: `int` the pixel with ignore label will not used for training
         and evaluation.
       aug_rand_hflip: `bool`, if True, augment training with random horizontal
         flip.
       preserve_aspect_ratio: `bool`, if True, the aspect ratio is preserved,
         otherwise, the image is resized to output_size.
       aug_scale_min: `float`, the minimum scale applied to `output_size` for

@@ -84,6 +91,7 @@ class Parser(parser.Parser):
     if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
       raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
                        'specified when resize_eval_groundtruth is False.')
+    self._gt_is_matting_map = gt_is_matting_map
     self._groundtruth_padded_size = groundtruth_padded_size
     self._ignore_label = ignore_label
     self._preserve_aspect_ratio = preserve_aspect_ratio

@@ -99,8 +107,8 @@ class Parser(parser.Parser):
   def _prepare_image_and_label(self, data):
     """Prepare normalized image and label."""
     image = tf.io.decode_image(data['image/encoded'], channels=3)
-    label = tf.io.decode_image(data['image/segmentation/class/encoded'],
-                               channels=1)
+    label = tf.io.decode_image(
+        data['image/segmentation/class/encoded'], channels=1)
     height = data['image/height']
     width = data['image/width']
     image = tf.reshape(image, (height, width, 3))

@@ -122,6 +130,16 @@ class Parser(parser.Parser):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)

+    # Normalize the label into the range of 0 and 1 for matting groundtruth.
+    # Note that the input groundtruth labels must be 0 to 255, and do not
+    # contain ignore_label. For gt_is_matting_map case, ignore_label is only
+    # used for padding the labels.
+    if self._gt_is_matting_map:
+      scale = tf.constant(255.0, dtype=tf.float32)
+      scale = tf.expand_dims(scale, axis=0)
+      scale = tf.expand_dims(scale, axis=0)
+      label = tf.cast(label, tf.float32) / scale
+
     if self._crop_size:
       label = tf.reshape(label, [data['image/height'], data['image/width'], 1])

@@ -132,8 +150,7 @@ class Parser(parser.Parser):
       label = tf.image.resize(label, self._output_size, method='nearest')
       image_mask = tf.concat([image, label], axis=2)
       image_mask_crop = tf.image.random_crop(image_mask,
                                              self._crop_size + [4])
       image = image_mask_crop[:, :, :-1]
       label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._crop_size)

@@ -159,13 +176,14 @@ class Parser(parser.Parser):
     # The label is first offset by +1 and then padded with 0.
     label += 1
     label = tf.expand_dims(label, axis=3)
     label = preprocess_ops.resize_and_crop_masks(
         label, image_scale,
         train_image_size, offset)
     label -= 1
-    label = tf.where(tf.equal(label, -1),
-                     self._ignore_label * tf.ones_like(label), label)
+    label = tf.where(
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
     label = tf.squeeze(label, axis=0)
     valid_mask = tf.not_equal(label, self._ignore_label)

     labels = {
         'masks': label,
         'valid_masks': valid_mask,

@@ -180,6 +198,12 @@ class Parser(parser.Parser):
   def _parse_eval_data(self, data):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)

+    # Binarize mask if groundtruth is a matting map
+    if self._gt_is_matting_map:
+      label = tf.divide(tf.cast(label, dtype=tf.float32), 255.0)
+      label = utils.binarize_matting_map(label)
+
     # The label is first offset by +1 and then padded with 0.
     label += 1
     label = tf.expand_dims(label, axis=3)

@@ -196,13 +220,13 @@ class Parser(parser.Parser):
       label = preprocess_ops.resize_and_crop_masks(label, image_scale,
                                                    self._output_size, offset)
     else:
       label = tf.image.pad_to_bounding_box(
           label, 0, 0,
           self._groundtruth_padded_size[0],
           self._groundtruth_padded_size[1])

     label -= 1
-    label = tf.where(tf.equal(label, -1),
-                     self._ignore_label * tf.ones_like(label), label)
+    label = tf.where(
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
     label = tf.squeeze(label, axis=0)
     valid_mask = tf.not_equal(label, self._ignore_label)
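The parser changes split the matting handling across the two paths: the training path only rescales the raw [0, 255] map into [0, 1], while _parse_eval_data also binarizes it via the new utils helper. A standalone sketch of both paths on a toy tensor (values and shape are illustrative only):

    import tensorflow as tf

    label = tf.constant([[0], [128], [255]], dtype=tf.uint8)  # toy matting map

    # Train path: scale raw [0, 255] matting values into [0, 1].
    train_label = tf.cast(label, tf.float32) / 255.0

    # Eval path: normalize, then threshold at 0.5 into hard {0, 1} classes,
    # mirroring utils.binarize_matting_map.
    eval_label = tf.cast(
        tf.greater(tf.cast(label, tf.float32) / 255.0, 0.5), tf.float32)

    print(train_label.numpy().ravel())  # [0.         0.5019608  1.        ]
    print(eval_label.numpy().ravel())   # [0. 1. 1.]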
official/vision/dataloaders/utils.py

@@ -67,3 +67,20 @@ def pad_groundtruths_to_fixed_size(groundtruths: Dict[str, tf.Tensor],
       groundtruths['attributes'][k] = preprocess_ops.clip_or_pad_to_fixed_size(
           v, size, -1)
   return groundtruths
+
+
+def binarize_matting_map(matting_map: tf.Tensor,
+                         threshold: float = 0.5) -> tf.Tensor:
+  """Binarizes a matting map.
+
+  If the matting_map value is above a threshold, set it as 1 otherwise 0. The
+  binarization is done for every element in the matting_map.
+
+  Args:
+    matting_map: The groundtruth in the matting map format.
+    threshold: The threshold used to binarize the matting map.
+
+  Returns:
+    The binarized labels (0 for BG, 1 for FG) as tf.float32.
+  """
+  return tf.cast(tf.greater(matting_map, threshold), tf.float32)
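A short usage sketch for the new helper; note that tf.greater is strict, so a value exactly at the threshold becomes background:

    import tensorflow as tf
    from official.vision.dataloaders import utils

    matting_map = tf.constant([[0.1, 0.5], [0.7, 0.9]])
    binary = utils.binarize_matting_map(matting_map, threshold=0.5)
    # 0.5 > 0.5 is False, so the threshold value itself maps to background:
    # binary == [[0., 0.], [1., 1.]]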
official/vision/losses/segmentation_losses.py

@@ -17,6 +17,7 @@
 import tensorflow as tf

 from official.modeling import tf_utils
+from official.vision.dataloaders import utils

 EPSILON = 1e-5

@@ -28,6 +29,7 @@ class SegmentationLoss:
                label_smoothing,
                class_weights,
                ignore_label,
+               gt_is_matting_map,
                use_groundtruth_dimension,
                top_k_percent_pixels=1.0):
     """Initializes `SegmentationLoss`.

@@ -37,6 +39,8 @@ class SegmentationLoss:
         spreading the amount of probability to all other label classes.
       class_weights: A float list containing the weight of each class.
       ignore_label: An integer specifying the ignore label.
+      gt_is_matting_map: If or not the groundtruth mask is a matting map. Note
+        that the matting map is only supported for 2 class segmentation.
       use_groundtruth_dimension: A boolean, whether to resize the output to
         match the dimension of the ground truth.
       top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its

@@ -46,6 +50,7 @@ class SegmentationLoss:
     self._label_smoothing = label_smoothing
     self._class_weights = class_weights
     self._ignore_label = ignore_label
+    self._gt_is_matting_map = gt_is_matting_map
     self._use_groundtruth_dimension = use_groundtruth_dimension
     self._top_k_percent_pixels = top_k_percent_pixels

@@ -73,8 +78,12 @@ class SegmentationLoss:
           labels, (height, width),
           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
-      labels = tf.cast(labels, tf.int32)
+
+    # Do not need to cast into int32 if it is a matting map
+    if not self._gt_is_matting_map:
+      labels = tf.cast(labels, tf.int32)
+
     valid_mask = tf.not_equal(labels, self._ignore_label)
     cross_entropy_loss = self.compute_pixelwise_loss(labels, logits, valid_mask,
                                                      **kwargs)

@@ -119,6 +128,12 @@ class SegmentationLoss:
           'Length of class_weights should be {}'.format(num_classes))
     valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=-1)

+    # If groundtruth is matting map, binarize the value to create the weight
+    # mask
+    if self._gt_is_matting_map:
+      labels = tf.cast(utils.binarize_matting_map(labels), tf.int32)
+
     weight_mask = tf.einsum(
         '...y,y->...',
         tf.one_hot(tf.squeeze(labels, axis=-1), num_classes, dtype=tf.float32),

@@ -131,8 +146,9 @@ class SegmentationLoss:
     This method can be overridden in subclasses for customizing loss function.

     Args:
-      labels: An int32 tensor in shape (batch_size, height, width, 1), which is
-        the label map of the ground truth.
+      labels: If groundtruth mask is not matting map, an int32 tensor which is
+        the label map of the groundtruth. If groundtruth mask is matting map,
+        an float32 tensor. The shape is always (batch_size, height, width, 1).
       logits: A float tensor in shape (batch_size, height, width, num_classes)
         which is the output of the network.
       **unused_kwargs: Unused keyword arguments.

@@ -140,10 +156,14 @@ class SegmentationLoss:
     Returns:
       A float tensor in shape (batch_size, height, width, num_classes).
     """
-    labels = tf.squeeze(labels, axis=-1)
     num_classes = logits.get_shape().as_list()[-1]
-    onehot_labels = tf.one_hot(labels, num_classes)
-    return onehot_labels * (
+    if self._gt_is_matting_map:
+      train_labels = tf.concat([1 - labels, labels], axis=-1)
+    else:
+      labels = tf.squeeze(labels, axis=-1)
+      train_labels = tf.one_hot(labels, num_classes)
+    return train_labels * (
         1 - self._label_smoothing) + self._label_smoothing / num_classes

   def aggregate_loss(self, pixelwise_loss, valid_mask):
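The last hunk is the heart of the matting-map path: instead of one-hot encoding integer labels, the float matting value itself becomes the foreground probability and its complement the background probability, which is why the docstring restricts matting maps to two-class segmentation. A small numeric sketch (values illustrative):

    import tensorflow as tf

    labels = tf.constant([[[[0.2]], [[0.8]]]])  # (1, 2, 1, 1) matting values
    # Foreground probability is the matting value, background its complement.
    train_labels = tf.concat([1 - labels, labels], axis=-1)  # (1, 2, 1, 2)

    label_smoothing = 0.1
    num_classes = 2
    smoothed = (train_labels * (1 - label_smoothing)
                + label_smoothing / num_classes)
    # e.g. a foreground value of 0.8 becomes 0.8 * 0.9 + 0.05 = 0.77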
official/vision/tasks/semantic_segmentation.py

@@ -35,15 +35,16 @@ class SemanticSegmentationTask(base_task.Task):
   def build_model(self):
     """Builds segmentation model."""
     input_specs = tf.keras.layers.InputSpec(
         shape=[None] +
         self.task_config.model.input_size)

     l2_weight_decay = self.task_config.losses.l2_weight_decay
     # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
     # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
     # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
-    l2_regularizer = (tf.keras.regularizers.l2(
-        l2_weight_decay / 2.0) if l2_weight_decay else None)
+    l2_regularizer = (
+        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
+        if l2_weight_decay else None)

     model = factory.build_segmentation_model(
         input_specs=input_specs,

@@ -85,6 +86,7 @@ class SemanticSegmentationTask(base_task.Task):
     """Builds classification input."""
     ignore_label = self.task_config.losses.ignore_label
+    gt_is_matting_map = self.task_config.losses.gt_is_matting_map

     if params.tfds_name:
       decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)

@@ -96,6 +98,7 @@ class SemanticSegmentationTask(base_task.Task):
         crop_size=params.crop_size,
         ignore_label=ignore_label,
         resize_eval_groundtruth=params.resize_eval_groundtruth,
+        gt_is_matting_map=gt_is_matting_map,
         groundtruth_padded_size=params.groundtruth_padded_size,
         aug_scale_min=params.aug_scale_min,
         aug_scale_max=params.aug_scale_max,

@@ -132,6 +135,7 @@ class SemanticSegmentationTask(base_task.Task):
         loss_params.label_smoothing,
         loss_params.class_weights,
         loss_params.ignore_label,
+        loss_params.gt_is_matting_map,
         use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
         top_k_percent_pixels=loss_params.top_k_percent_pixels)

@@ -140,8 +144,7 @@ class SemanticSegmentationTask(base_task.Task):
     if 'mask_scores' in model_outputs:
       mask_scoring_loss_fn = segmentation_losses.MaskScoringLoss(
           loss_params.ignore_label)
       total_loss += mask_scoring_loss_fn(
           model_outputs['mask_scores'],
           model_outputs['logits'],
           labels['masks'])

@@ -178,7 +181,8 @@ class SemanticSegmentationTask(base_task.Task):
     """Gets streaming metrics for training/validation."""
     metrics = []
     if training and self.task_config.evaluation.report_train_mean_iou:
       metrics.append(segmentation_metrics.MeanIoU(
           name='mean_iou',
           num_classes=self.task_config.model.num_classes,
           rescale_predictions=False,

@@ -202,8 +206,8 @@ class SemanticSegmentationTask(base_task.Task):
           tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))

       # Update state on CPU if TPUStrategy due to dynamic resizing.
       self._process_iou_metric_on_cpu = isinstance(
           tf.distribute.get_strategy(),
           tf.distribute.TPUStrategy)

     return metrics

@@ -238,8 +242,7 @@ class SemanticSegmentationTask(base_task.Task):
       outputs = {'logits': outputs}
     # Casting output layer as float32 is necessary when mixed_precision is
     # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
     outputs = tf.nest.map_structure(
         lambda x: tf.cast(x, tf.float32), outputs)

     # Computes per-replica loss.
     loss = self.build_losses(

@@ -296,8 +299,8 @@ class SemanticSegmentationTask(base_task.Task):
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

     if self.task_config.validation_data.resize_eval_groundtruth:
-      loss = self.build_losses(model_outputs=outputs, labels=labels,
-                               aux_losses=model.losses)
+      loss = self.build_losses(
+          model_outputs=outputs, labels=labels, aux_losses=model.losses)
     else:
       loss = 0
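For reference, a condensed sketch of the loss construction as the task now performs it, mirroring the @@ -132,6 +135,7 @@ hunk above; the wrapper function name is hypothetical and everything around the call is elided:

    from official.vision.losses import segmentation_losses

    def build_segmentation_loss(loss_params):
      # `gt_is_matting_map` is passed positionally after `ignore_label`,
      # matching the updated SegmentationLoss.__init__ in this commit.
      return segmentation_losses.SegmentationLoss(
          loss_params.label_smoothing,
          loss_params.class_weights,
          loss_params.ignore_label,
          loss_params.gt_is_matting_map,
          use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
          top_k_percent_pixels=loss_params.top_k_percent_pixels)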