"...images/git@developer.sourcefind.cn:OpenDAS/nerfacc.git" did not exist on "ba78cbdcd0177db9d6ec0947ddf81b5157f3535c"
Commit 69231ce9 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 476005369
parent 37e76715
@@ -109,6 +109,7 @@ class Losses(maskrcnn.Losses):
   """Panoptic Mask R-CNN loss config."""
   semantic_segmentation_label_smoothing: float = 0.0
   semantic_segmentation_ignore_label: int = 255
+  semantic_segmentation_gt_is_matting_map: bool = False
   semantic_segmentation_class_weights: List[float] = dataclasses.field(
       default_factory=list)
   semantic_segmentation_use_groundtruth_dimension: bool = True
@@ -181,6 +181,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
         label_smoothing=params.semantic_segmentation_label_smoothing,
         class_weights=params.semantic_segmentation_class_weights,
         ignore_label=params.semantic_segmentation_ignore_label,
+        gt_is_matting_map=params.semantic_segmentation_gt_is_matting_map,
         use_groundtruth_dimension=use_groundtruth_dimension,
         top_k_percent_pixels=params.semantic_segmentation_top_k_percent_pixels)
@@ -104,6 +104,7 @@ class Losses(hyperparams.Config):
   loss_weight: float = 1.0
   label_smoothing: float = 0.0
   ignore_label: int = 255
+  gt_is_matting_map: bool = False
   class_weights: List[float] = dataclasses.field(default_factory=list)
   l2_weight_decay: float = 0.0
   use_groundtruth_dimension: bool = True
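
Note: a minimal sketch of enabling the new flag through the `Losses` config added above. The module path and surrounding values are assumptions for illustration, not part of this commit:

    # Hypothetical usage sketch; module path assumed from this repository's
    # layout (official/vision/configs/semantic_segmentation.py).
    from official.vision.configs import semantic_segmentation as seg_cfg

    losses = seg_cfg.Losses(
        ignore_label=255,        # still used, but only to pad matting labels
        gt_is_matting_map=True,  # treat 0-255 masks as matting maps
    )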
@@ -132,8 +133,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
   evaluation: Evaluation = Evaluation()
   train_input_partition_dims: List[int] = dataclasses.field(
       default_factory=list)
-  eval_input_partition_dims: List[int] = dataclasses.field(
-      default_factory=list)
+  eval_input_partition_dims: List[int] = dataclasses.field(default_factory=list)
   init_checkpoint: Optional[str] = None
   init_checkpoint_modules: Union[
       str, List[str]] = 'all'  # all, backbone, and/or decoder
@@ -151,6 +151,7 @@ def semantic_segmentation() -> cfg.ExperimentConfig:
           'task.validation_data.is_training != None'
       ])

+
 # PASCAL VOC 2012 Dataset
 PASCAL_TRAIN_EXAMPLES = 10582
 PASCAL_VAL_EXAMPLES = 1449
@@ -174,11 +175,15 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
               num_classes=21,
               input_size=[None, None, 3],
               backbone=backbones.Backbone(
-                  type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                      model_id=101, output_stride=output_stride,
-                      multigrid=multigrid, stem_type=stem_type)),
+                  type='dilated_resnet',
+                  dilated_resnet=backbones.DilatedResNet(
+                      model_id=101,
+                      output_stride=output_stride,
+                      multigrid=multigrid,
+                      stem_type=stem_type)),
               decoder=decoders.Decoder(
-                  type='aspp', aspp=decoders.ASPP(
+                  type='aspp',
+                  aspp=decoders.ASPP(
                       level=level, dilation_rates=aspp_dilation_rates)),
               head=SegmentationHead(level=level, num_convs=0),
               norm_activation=common.NormActivation(
@@ -262,9 +267,12 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
               num_classes=21,
               input_size=[None, None, 3],
               backbone=backbones.Backbone(
-                  type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                      model_id=101, output_stride=output_stride,
-                      stem_type=stem_type, multigrid=multigrid)),
+                  type='dilated_resnet',
+                  dilated_resnet=backbones.DilatedResNet(
+                      model_id=101,
+                      output_stride=output_stride,
+                      stem_type=stem_type,
+                      multigrid=multigrid)),
               decoder=decoders.Decoder(
                   type='aspp',
                   aspp=decoders.ASPP(
@@ -356,8 +364,7 @@ def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
               decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
               head=SegmentationHead(level=3, num_convs=3),
               norm_activation=common.NormActivation(
-                  activation='swish',
-                  use_sync_bn=True)),
+                  activation='swish', use_sync_bn=True)),
           losses=Losses(l2_weight_decay=1e-4),
           train_data=DataConfig(
               input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
@@ -530,13 +537,17 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
               num_classes=19,
               input_size=[None, None, 3],
               backbone=backbones.Backbone(
-                  type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
-                      model_id=101, output_stride=output_stride,
-                      stem_type=stem_type, multigrid=multigrid)),
+                  type='dilated_resnet',
+                  dilated_resnet=backbones.DilatedResNet(
+                      model_id=101,
+                      output_stride=output_stride,
+                      stem_type=stem_type,
+                      multigrid=multigrid)),
               decoder=decoders.Decoder(
                   type='aspp',
                   aspp=decoders.ASPP(
-                      level=level, dilation_rates=aspp_dilation_rates,
+                      level=level,
+                      dilation_rates=aspp_dilation_rates,
                       pool_kernel_size=[512, 1024])),
               head=SegmentationHead(
                   level=level,
@@ -17,6 +17,7 @@
 import tensorflow as tf

 from official.vision.dataloaders import decoder
 from official.vision.dataloaders import parser
+from official.vision.dataloaders import utils
 from official.vision.ops import preprocess_ops
@@ -25,26 +26,29 @@ class Decoder(decoder.Decoder):
   def __init__(self):
     self._keys_to_features = {
-        'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
-        'image/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
-        'image/width': tf.io.FixedLenFeature((), tf.int64, default_value=0),
+        'image/encoded':
+            tf.io.FixedLenFeature((), tf.string, default_value=''),
+        'image/height':
+            tf.io.FixedLenFeature((), tf.int64, default_value=0),
+        'image/width':
+            tf.io.FixedLenFeature((), tf.int64, default_value=0),
         'image/segmentation/class/encoded':
             tf.io.FixedLenFeature((), tf.string, default_value='')
     }

   def decode(self, serialized_example):
-    return tf.io.parse_single_example(
-        serialized_example, self._keys_to_features)
+    return tf.io.parse_single_example(serialized_example,
+                                      self._keys_to_features)


 class Parser(parser.Parser):
-  """Parser to parse an image and its annotations into a dictionary of tensors.
-  """
+  """Parser to parse an image and its annotations into a dictionary of tensors."""

   def __init__(self,
                output_size,
                crop_size=None,
                resize_eval_groundtruth=True,
+               gt_is_matting_map=False,
                groundtruth_padded_size=None,
                ignore_label=255,
                aug_rand_hflip=False,
@@ -63,13 +67,16 @@ class Parser(parser.Parser):
         original image sizes.
       resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
         resized to output_size.
+      gt_is_matting_map: `bool`, if True, the expected mask is in the range
+        between 0 and 255. The parser will normalize the value of the mask into
+        the range between 0 and 1.
       groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
         resize_eval_groundtruth is set to False, the groundtruth masks are
         padded to this size.
       ignore_label: `int` the pixel with ignore label will not used for training
         and evaluation.
-      aug_rand_hflip: `bool`, if True, augment training with random
-        horizontal flip.
+      aug_rand_hflip: `bool`, if True, augment training with random horizontal
+        flip.
       preserve_aspect_ratio: `bool`, if True, the aspect ratio is preserved,
         otherwise, the image is resized to output_size.
       aug_scale_min: `float`, the minimum scale applied to `output_size` for
@@ -84,6 +91,7 @@ class Parser(parser.Parser):
     if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
       raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
                        'specified when resize_eval_groundtruth is False.')
+    self._gt_is_matting_map = gt_is_matting_map
     self._groundtruth_padded_size = groundtruth_padded_size
     self._ignore_label = ignore_label
     self._preserve_aspect_ratio = preserve_aspect_ratio
@@ -99,8 +107,8 @@ class Parser(parser.Parser):
   def _prepare_image_and_label(self, data):
     """Prepare normalized image and label."""
     image = tf.io.decode_image(data['image/encoded'], channels=3)
-    label = tf.io.decode_image(data['image/segmentation/class/encoded'],
-                               channels=1)
+    label = tf.io.decode_image(
+        data['image/segmentation/class/encoded'], channels=1)
     height = data['image/height']
     width = data['image/width']
     image = tf.reshape(image, (height, width, 3))
@@ -122,6 +130,16 @@ class Parser(parser.Parser):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)

+    # Normalize the label into the range of 0 and 1 for matting groundtruth.
+    # Note that the input groundtruth labels must be 0 to 255, and do not
+    # contain ignore_label. For gt_is_matting_map case, ignore_label is only
+    # used for padding the labels.
+    if self._gt_is_matting_map:
+      scale = tf.constant(255.0, dtype=tf.float32)
+      scale = tf.expand_dims(scale, axis=0)
+      scale = tf.expand_dims(scale, axis=0)
+      label = tf.cast(label, tf.float32) / scale
+
     if self._crop_size:
       label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
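
Note: the block above rescales a 0-255 matting groundtruth into [0, 1]. A tiny standalone sketch of the same arithmetic (values illustrative):

    import tensorflow as tf

    # A 2x2 matting map as decoded from the TF Example (uint8, 0..255).
    label = tf.constant([[[0], [128]], [[64], [255]]], dtype=tf.uint8)

    # Same effect as the parser's broadcastable `scale` tensor above.
    normalized = tf.cast(label, tf.float32) / 255.0
    # -> [[0.0], [0.502]], [[0.251], [1.0]] (shape (2, 2, 1), values in [0, 1])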
@@ -132,8 +150,7 @@ class Parser(parser.Parser):
       label = tf.image.resize(label, self._output_size, method='nearest')
       image_mask = tf.concat([image, label], axis=2)
-      image_mask_crop = tf.image.random_crop(image_mask,
-                                             self._crop_size + [4])
+      image_mask_crop = tf.image.random_crop(image_mask, self._crop_size + [4])
       image = image_mask_crop[:, :, :-1]
       label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._crop_size)
@@ -159,13 +176,14 @@ class Parser(parser.Parser):
     # The label is first offset by +1 and then padded with 0.
     label += 1
     label = tf.expand_dims(label, axis=3)
-    label = preprocess_ops.resize_and_crop_masks(
-        label, image_scale, train_image_size, offset)
+    label = preprocess_ops.resize_and_crop_masks(label, image_scale,
+                                                 train_image_size, offset)
     label -= 1
-    label = tf.where(tf.equal(label, -1),
-                     self._ignore_label * tf.ones_like(label), label)
+    label = tf.where(
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
     label = tf.squeeze(label, axis=0)
     valid_mask = tf.not_equal(label, self._ignore_label)
+
     labels = {
         'masks': label,
         'valid_masks': valid_mask,
@@ -180,6 +198,12 @@ class Parser(parser.Parser):
   def _parse_eval_data(self, data):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)

+    # Binarize mask if groundtruth is a matting map
+    if self._gt_is_matting_map:
+      label = tf.divide(tf.cast(label, dtype=tf.float32), 255.0)
+      label = utils.binarize_matting_map(label)
+
     # The label is first offset by +1 and then padded with 0.
     label += 1
     label = tf.expand_dims(label, axis=3)
@@ -196,13 +220,13 @@ class Parser(parser.Parser):
       label = preprocess_ops.resize_and_crop_masks(label, image_scale,
                                                    self._output_size, offset)
     else:
-      label = tf.image.pad_to_bounding_box(
-          label, 0, 0, self._groundtruth_padded_size[0],
-          self._groundtruth_padded_size[1])
+      label = tf.image.pad_to_bounding_box(label, 0, 0,
+                                           self._groundtruth_padded_size[0],
+                                           self._groundtruth_padded_size[1])
     label -= 1
-    label = tf.where(tf.equal(label, -1),
-                     self._ignore_label * tf.ones_like(label), label)
+    label = tf.where(
+        tf.equal(label, -1), self._ignore_label * tf.ones_like(label), label)
     label = tf.squeeze(label, axis=0)
     valid_mask = tf.not_equal(label, self._ignore_label)
@@ -67,3 +67,20 @@ def pad_groundtruths_to_fixed_size(groundtruths: Dict[str, tf.Tensor],
       groundtruths['attributes'][k] = preprocess_ops.clip_or_pad_to_fixed_size(
           v, size, -1)
   return groundtruths
+
+
+def binarize_matting_map(matting_map: tf.Tensor,
+                         threshold: float = 0.5) -> tf.Tensor:
+  """Binarizes a matting map.
+
+  If the matting_map value is above a threshold, set it as 1 otherwise 0. The
+  binarization is done for every element in the matting_map.
+
+  Args:
+    matting_map: The groundtruth in the matting map format.
+    threshold: The threshold used to binarize the matting map.
+
+  Returns:
+    The binarized labels (0 for BG, 1 for FG) as tf.float32.
+  """
+  return tf.cast(tf.greater(matting_map, threshold), tf.float32)
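
Note: a quick usage sketch for the new helper (inputs illustrative). The comparison is strictly greater-than, so a value exactly at the 0.5 threshold maps to background:

    import tensorflow as tf
    from official.vision.dataloaders import utils

    matting = tf.constant([0.0, 0.25, 0.5, 0.75, 1.0])
    binary = utils.binarize_matting_map(matting)
    # -> [0., 0., 0., 1., 1.] (0.5 itself is not strictly greater than 0.5)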
@@ -17,6 +17,7 @@
 import tensorflow as tf

 from official.modeling import tf_utils
+from official.vision.dataloaders import utils

 EPSILON = 1e-5
@@ -28,6 +29,7 @@ class SegmentationLoss:
                label_smoothing,
                class_weights,
                ignore_label,
+               gt_is_matting_map,
                use_groundtruth_dimension,
                top_k_percent_pixels=1.0):
     """Initializes `SegmentationLoss`.
@@ -37,6 +39,8 @@ class SegmentationLoss:
         spreading the amount of probability to all other label classes.
       class_weights: A float list containing the weight of each class.
       ignore_label: An integer specifying the ignore label.
+      gt_is_matting_map: Whether or not the groundtruth mask is a matting map.
+        Note that the matting map is only supported for 2-class segmentation.
       use_groundtruth_dimension: A boolean, whether to resize the output to
         match the dimension of the ground truth.
       top_k_percent_pixels: A float, the value lies in [0.0, 1.0]. When its
@@ -46,6 +50,7 @@ class SegmentationLoss:
     self._label_smoothing = label_smoothing
     self._class_weights = class_weights
     self._ignore_label = ignore_label
+    self._gt_is_matting_map = gt_is_matting_map
     self._use_groundtruth_dimension = use_groundtruth_dimension
     self._top_k_percent_pixels = top_k_percent_pixels
@@ -73,8 +78,12 @@ class SegmentationLoss:
           labels, (height, width),
           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
-    labels = tf.cast(labels, tf.int32)
+
+    # Do not need to cast into int32 if it is a matting map
+    if not self._gt_is_matting_map:
+      labels = tf.cast(labels, tf.int32)
+
     valid_mask = tf.not_equal(labels, self._ignore_label)
     cross_entropy_loss = self.compute_pixelwise_loss(labels, logits, valid_mask,
                                                      **kwargs)
@@ -119,6 +128,12 @@ class SegmentationLoss:
           'Length of class_weights should be {}'.format(num_classes))
     valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=-1)
+
+    # If groundtruth is matting map, binarize the value to create the weight
+    # mask
+    if self._gt_is_matting_map:
+      labels = tf.cast(utils.binarize_matting_map(labels), tf.int32)
+
     weight_mask = tf.einsum(
         '...y,y->...',
         tf.one_hot(tf.squeeze(labels, axis=-1), num_classes, dtype=tf.float32),
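
Note: once binarized, the labels are ordinary {0, 1} class ids, so the existing einsum weighting applies unchanged. A small worked sketch of that weighting (the class weights here are illustrative):

    import tensorflow as tf

    num_classes = 2
    class_weights = tf.constant([1.0, 3.0])  # illustrative FG upweighting
    labels = tf.constant([[[0], [1]], [[1], [0]]], dtype=tf.int32)

    # '...y,y->...' looks up each pixel's class weight via its one-hot row.
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(tf.squeeze(labels, axis=-1), num_classes, dtype=tf.float32),
        class_weights)
    # -> [[1., 3.], [3., 1.]]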
@@ -131,8 +146,9 @@ class SegmentationLoss:
     This method can be overridden in subclasses for customizing loss function.

     Args:
-      labels: An int32 tensor in shape (batch_size, height, width, 1), which is
-        the label map of the ground truth.
+      labels: If the groundtruth mask is not a matting map, an int32 tensor
+        which is the label map of the groundtruth. If it is a matting map, a
+        float32 tensor. The shape is always (batch_size, height, width, 1).
       logits: A float tensor in shape (batch_size, height, width, num_classes)
         which is the output of the network.
       **unused_kwargs: Unused keyword arguments.
@@ -140,10 +156,14 @@ class SegmentationLoss:
     Returns:
       A float tensor in shape (batch_size, height, width, num_classes).
     """
-    labels = tf.squeeze(labels, axis=-1)
     num_classes = logits.get_shape().as_list()[-1]
-    onehot_labels = tf.one_hot(labels, num_classes)
-    return onehot_labels * (
+
+    if self._gt_is_matting_map:
+      train_labels = tf.concat([1 - labels, labels], axis=-1)
+    else:
+      labels = tf.squeeze(labels, axis=-1)
+      train_labels = tf.one_hot(labels, num_classes)
+    return train_labels * (
         1 - self._label_smoothing) + self._label_smoothing / num_classes

   def aggregate_loss(self, pixelwise_loss, valid_mask):
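
Note: with `gt_is_matting_map=True`, `labels` reaches this method as float alpha values in [0, 1] with a trailing channel dimension, and the two-class target comes straight from the alpha instead of a one-hot. A worked sketch (label smoothing assumed to be 0.0):

    import tensorflow as tf

    # Alpha matte for two pixels, shape (batch=1, h=1, w=2, 1).
    labels = tf.constant([[[[0.25], [1.0]]]], dtype=tf.float32)

    # Channel 0 becomes background probability, channel 1 foreground.
    train_labels = tf.concat([1 - labels, labels], axis=-1)
    # -> [[[[0.75, 0.25], [0.0, 1.0]]]]; for hard {0, 1} alphas this matches
    # the tf.one_hot path used for integer labels.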
@@ -35,15 +35,16 @@ class SemanticSegmentationTask(base_task.Task):
   def build_model(self):
     """Builds segmentation model."""
-    input_specs = tf.keras.layers.InputSpec(
-        shape=[None] + self.task_config.model.input_size)
+    input_specs = tf.keras.layers.InputSpec(shape=[None] +
+                                            self.task_config.model.input_size)

     l2_weight_decay = self.task_config.losses.l2_weight_decay
     # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
     # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
     # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
-    l2_regularizer = (tf.keras.regularizers.l2(
-        l2_weight_decay / 2.0) if l2_weight_decay else None)
+    l2_regularizer = (
+        tf.keras.regularizers.l2(l2_weight_decay /
+                                 2.0) if l2_weight_decay else None)

     model = factory.build_segmentation_model(
         input_specs=input_specs,
@@ -85,6 +86,7 @@ class SemanticSegmentationTask(base_task.Task):
     """Builds classification input."""

     ignore_label = self.task_config.losses.ignore_label
+    gt_is_matting_map = self.task_config.losses.gt_is_matting_map

     if params.tfds_name:
       decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)

@@ -96,6 +98,7 @@ class SemanticSegmentationTask(base_task.Task):
         crop_size=params.crop_size,
         ignore_label=ignore_label,
         resize_eval_groundtruth=params.resize_eval_groundtruth,
+        gt_is_matting_map=gt_is_matting_map,
         groundtruth_padded_size=params.groundtruth_padded_size,
         aug_scale_min=params.aug_scale_min,
         aug_scale_max=params.aug_scale_max,
@@ -132,6 +135,7 @@ class SemanticSegmentationTask(base_task.Task):
         loss_params.label_smoothing,
         loss_params.class_weights,
         loss_params.ignore_label,
+        loss_params.gt_is_matting_map,
         use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
         top_k_percent_pixels=loss_params.top_k_percent_pixels)
@@ -140,10 +144,9 @@ class SemanticSegmentationTask(base_task.Task):
     if 'mask_scores' in model_outputs:
       mask_scoring_loss_fn = segmentation_losses.MaskScoringLoss(
           loss_params.ignore_label)
-      total_loss += mask_scoring_loss_fn(
-          model_outputs['mask_scores'],
-          model_outputs['logits'],
-          labels['masks'])
+      total_loss += mask_scoring_loss_fn(model_outputs['mask_scores'],
+                                         model_outputs['logits'],
+                                         labels['masks'])

     if aux_losses:
       total_loss += tf.add_n(aux_losses)
@@ -178,11 +181,12 @@ class SemanticSegmentationTask(base_task.Task):
     """Gets streaming metrics for training/validation."""
     metrics = []
     if training and self.task_config.evaluation.report_train_mean_iou:
-      metrics.append(segmentation_metrics.MeanIoU(
-          name='mean_iou',
-          num_classes=self.task_config.model.num_classes,
-          rescale_predictions=False,
-          dtype=tf.float32))
+      metrics.append(
+          segmentation_metrics.MeanIoU(
+              name='mean_iou',
+              num_classes=self.task_config.model.num_classes,
+              rescale_predictions=False,
+              dtype=tf.float32))
       if self.task_config.model.get('mask_scoring_head'):
         metrics.append(
             tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
@@ -202,8 +206,8 @@ class SemanticSegmentationTask(base_task.Task):
             tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))

       # Update state on CPU if TPUStrategy due to dynamic resizing.
-      self._process_iou_metric_on_cpu = isinstance(
-          tf.distribute.get_strategy(), tf.distribute.TPUStrategy)
+      self._process_iou_metric_on_cpu = isinstance(tf.distribute.get_strategy(),
+                                                   tf.distribute.TPUStrategy)

     return metrics
@@ -238,8 +242,7 @@ class SemanticSegmentationTask(base_task.Task):
       outputs = {'logits': outputs}
     # Casting output layer as float32 is necessary when mixed_precision is
     # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
-    outputs = tf.nest.map_structure(
-        lambda x: tf.cast(x, tf.float32), outputs)
+    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

     # Computes per-replica loss.
     loss = self.build_losses(
...@@ -296,8 +299,8 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -296,8 +299,8 @@ class SemanticSegmentationTask(base_task.Task):
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.validation_data.resize_eval_groundtruth: if self.task_config.validation_data.resize_eval_groundtruth:
loss = self.build_losses(model_outputs=outputs, labels=labels, loss = self.build_losses(
aux_losses=model.losses) model_outputs=outputs, labels=labels, aux_losses=model.losses)
else: else:
loss = 0 loss = 0
......