Commit f31423a8 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 416657609
parent 03a9dc97
@@ -47,10 +47,10 @@ class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase):
     model = segmentation_model.SegmentationModel(
         backbone=backbone, decoder=decoder, head=head)
-    logits = model(inputs)
+    outputs = model(inputs)
     self.assertAllEqual(
         [2, input_size[0], input_size[0], input_size[1], num_classes],
-        logits.numpy().shape)
+        outputs['logits'].numpy().shape)

   def test_serialize_deserialize(self):
     """Validate the network can be serialized and deserialized."""
...
@@ -56,4 +56,4 @@ class SegmentationModule(export_base.ExportModule):
     outputs = self.inference_step(images)
     output_key = 'logits' if self.params.task.model.head.output_logits else 'probs'
-    return {output_key: outputs}
+    return {output_key: outputs['logits']}
@@ -104,7 +104,8 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
     # outputs equal.
     expected_output = module.model(image_tensor, training=False)
     out = segmentation_fn(tf.constant(images))
-    self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
+    self.assertAllClose(out['logits'].numpy(),
+                        expected_output['logits'].numpy())

 if __name__ == '__main__':
...
@@ -198,6 +198,8 @@ class SemanticSegmentation3DTask(base_task.Task):
       # Casting output layer as float32 is necessary when mixed_precision is
       # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
       outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
+      outputs = outputs['logits']
+
       if self.task_config.model.head.output_logits:
         outputs = tf.nn.softmax(outputs)
@@ -258,6 +260,7 @@ class SemanticSegmentation3DTask(base_task.Task):
     outputs = self.inference_step(features, model)
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
+    outputs = outputs['logits']
     if self.task_config.model.head.output_logits:
       outputs = tf.nn.softmax(outputs)
@@ -268,8 +271,8 @@ class SemanticSegmentation3DTask(base_task.Task):
     # Compute dice score metrics on CPU.
     for metric in self.metrics:
       labels = tf.cast(labels, tf.float32)
-      outputs = tf.cast(outputs, tf.float32)
-      logs.update({metric.name: (labels, outputs)})
+      logits = tf.cast(outputs, tf.float32)
+      logs.update({metric.name: (labels, logits)})
     return logs
...
@@ -75,6 +75,16 @@ class SegmentationHead(hyperparams.Config):
   decoder_max_level: Optional[Union[int, str]] = None


+@dataclasses.dataclass
+class MaskScoringHead(hyperparams.Config):
+  """Mask Scoring head config."""
+  num_convs: int = 4
+  num_filters: int = 128
+  fc_input_size: List[int] = dataclasses.field(default_factory=list)
+  num_fcs: int = 2
+  fc_dims: int = 1024
+
+
 @dataclasses.dataclass
 class SemanticSegmentationModel(hyperparams.Config):
   """Semantic segmentation model config."""
@@ -86,6 +96,7 @@ class SemanticSegmentationModel(hyperparams.Config):
   backbone: backbones.Backbone = backbones.Backbone(
       type='resnet', resnet=backbones.ResNet())
   decoder: decoders.Decoder = decoders.Decoder(type='identity')
+  mask_scoring_head: Optional[MaskScoringHead] = None
   norm_activation: common.NormActivation = common.NormActivation()
...
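
As a usage note, a minimal sketch of turning the new head on through these dataclasses; the config module path and the num_classes value are assumptions for illustration, the field names come from the diff above:

    # Hypothetical config sketch; module path and values are assumed.
    from official.vision.beta.configs import semantic_segmentation as seg_cfg

    model_cfg = seg_cfg.SemanticSegmentationModel(
        num_classes=21,
        mask_scoring_head=seg_cfg.MaskScoringHead(
            num_convs=4,
            num_filters=128,
            fc_input_size=[4, 4],  # spatial size the logits are resized to
            num_fcs=2,
            fc_dims=1024))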
@@ -17,6 +17,8 @@
 # Import libraries
 import tensorflow as tf

+from official.modeling import tf_utils
+
 EPSILON = 1e-5
@@ -87,3 +89,46 @@ class SegmentationLoss:
     loss = tf.reduce_sum(top_k_losses) / normalizer
     return loss
+
+
+def get_actual_mask_scores(logits, labels, ignore_label):
+  """Gets actual mask scores."""
+  _, height, width, num_classes = logits.get_shape().as_list()
+  batch_size = tf.shape(logits)[0]
+  logits = tf.stop_gradient(logits)
+  labels = tf.image.resize(
+      labels, (height, width),
+      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+  predicted_labels = tf.argmax(logits, -1, output_type=tf.int32)
+  flat_predictions = tf.reshape(predicted_labels, [batch_size, -1])
+  flat_labels = tf.cast(tf.reshape(labels, [batch_size, -1]), tf.int32)
+
+  one_hot_predictions = tf.one_hot(
+      flat_predictions, num_classes, on_value=True, off_value=False)
+  one_hot_labels = tf.one_hot(
+      flat_labels, num_classes, on_value=True, off_value=False)
+  keep_mask = tf.not_equal(flat_labels, ignore_label)
+  keep_mask = tf.expand_dims(keep_mask, 2)
+
+  overlap = tf.logical_and(one_hot_predictions, one_hot_labels)
+  overlap = tf.logical_and(overlap, keep_mask)
+  overlap = tf.reduce_sum(tf.cast(overlap, tf.float32), axis=1)
+
+  union = tf.logical_or(one_hot_predictions, one_hot_labels)
+  union = tf.logical_and(union, keep_mask)
+  union = tf.reduce_sum(tf.cast(union, tf.float32), axis=1)
+
+  actual_scores = tf.divide(overlap, tf.maximum(union, EPSILON))
+  return actual_scores
+
+
+class MaskScoringLoss:
+  """Mask Scoring loss."""
+
+  def __init__(self, ignore_label):
+    self._ignore_label = ignore_label
+    self._mse_loss = tf.keras.losses.MeanSquaredError(
+        reduction=tf.keras.losses.Reduction.NONE)
+
+  def __call__(self, predicted_scores, logits, labels):
+    actual_scores = get_actual_mask_scores(logits, labels, self._ignore_label)
+    loss = tf_utils.safe_mean(self._mse_loss(actual_scores, predicted_scores))
+    return loss
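
In other words, the regression target is the per-class IoU between argmax(logits) and the nearest-neighbor-resized ground truth, with ignore_label pixels masked out of both overlap and union; the predicted scores are regressed onto that target with an MSE loss. A small self-contained sketch of exercising the loss (shapes and the ignore_label value are illustrative):

    import tensorflow as tf
    # assuming MaskScoringLoss from the added code above

    logits = tf.random.normal([2, 8, 8, 3])                  # [B, H, W, C]
    labels = tf.random.uniform([2, 8, 8, 1], maxval=3, dtype=tf.int32)
    predicted_scores = tf.random.uniform([2, 3])             # head output, [B, C]

    loss_fn = MaskScoringLoss(ignore_label=255)
    loss = loss_fn(predicted_scores, logits, labels)         # scalar MSE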
@@ -369,5 +369,17 @@ def build_segmentation_model(
       norm_epsilon=norm_activation_config.norm_epsilon,
       kernel_regularizer=l2_regularizer)

-  model = segmentation_model.SegmentationModel(backbone, decoder, head)
+  mask_scoring_head = None
+  if model_config.mask_scoring_head:
+    mask_scoring_head = segmentation_heads.MaskScoring(
+        num_classes=model_config.num_classes,
+        **model_config.mask_scoring_head.as_dict(),
+        activation=norm_activation_config.activation,
+        use_sync_bn=norm_activation_config.use_sync_bn,
+        norm_momentum=norm_activation_config.norm_momentum,
+        norm_epsilon=norm_activation_config.norm_epsilon,
+        kernel_regularizer=l2_regularizer)
+
+  model = segmentation_model.SegmentationModel(
+      backbone, decoder, head, mask_scoring_head=mask_scoring_head)
   return model
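
For orientation, roughly how the builder is now invoked; the factory module name, input shape, and regularizer value are assumptions, and model_cfg is the config sketched earlier:

    # Illustrative only; builder name per the diff above, values are placeholders.
    import tensorflow as tf
    from official.vision.beta.modeling import factory

    model = factory.build_segmentation_model(
        input_specs=tf.keras.layers.InputSpec(shape=[None, 512, 512, 3]),
        model_config=model_cfg,  # includes mask_scoring_head, sketched earlier
        l2_regularizer=tf.keras.regularizers.l2(1e-4))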
@@ -13,7 +13,7 @@
 # limitations under the License.

 """Contains definitions of segmentation heads."""
-from typing import List, Union, Optional, Mapping, Tuple
+from typing import List, Union, Optional, Mapping, Tuple, Any

 import tensorflow as tf

 from official.modeling import tf_utils
@@ -21,6 +21,176 @@ from official.vision.beta.modeling.layers import nn_layers
 from official.vision.beta.ops import spatial_transform_ops
+class MaskScoring(tf.keras.Model):
+  """Creates a mask scoring layer.
+
+  This implements the mask scoring layer from the paper:
+
+  Zhaojin Huang, Lichao Huang, Yongchao Gong, Chang Huang, Xinggang Wang.
+  Mask Scoring R-CNN.
+  (https://arxiv.org/pdf/1903.00241.pdf)
+  """
+
+  def __init__(
+      self,
+      num_classes: int,
+      fc_input_size: List[int],
+      num_convs: int = 3,
+      num_filters: int = 256,
+      fc_dims: int = 1024,
+      num_fcs: int = 2,
+      activation: str = 'relu',
+      use_sync_bn: bool = False,
+      norm_momentum: float = 0.99,
+      norm_epsilon: float = 0.001,
+      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+      **kwargs):
+    """Initializes mask scoring layer.
+
+    Args:
+      num_classes: An `int` for number of classes.
+      fc_input_size: A List of `int` for the input size of the
+        fully connected layers.
+      num_convs: An `int` for number of conv layers.
+      num_filters: An `int` for the number of filters for conv layers.
+      fc_dims: An `int` for the number of units in each fully connected layer.
+      num_fcs: An `int` for number of fully connected layers.
+      activation: A `str` name of the activation function.
+      use_sync_bn: A bool, whether or not to use sync batch normalization.
+      norm_momentum: A float for the momentum in BatchNorm. Defaults to 0.99.
+      norm_epsilon: A float for the epsilon value in BatchNorm. Defaults to
+        0.001.
+      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
+        Conv2D. Default is None.
+      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
+        Conv2D.
+      **kwargs: Additional keyword arguments to be passed.
+    """
+    super(MaskScoring, self).__init__(**kwargs)
+
+    self._config_dict = {
+        'num_classes': num_classes,
+        'num_convs': num_convs,
+        'num_filters': num_filters,
+        'fc_input_size': fc_input_size,
+        'fc_dims': fc_dims,
+        'num_fcs': num_fcs,
+        'use_sync_bn': use_sync_bn,
+        'norm_momentum': norm_momentum,
+        'norm_epsilon': norm_epsilon,
+        'activation': activation,
+        'kernel_regularizer': kernel_regularizer,
+        'bias_regularizer': bias_regularizer,
+    }
+    if tf.keras.backend.image_data_format() == 'channels_last':
+      self._bn_axis = -1
+    else:
+      self._bn_axis = 1
+    self._activation = tf_utils.get_activation(activation)
+
+  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
+    """Creates the variables of the mask scoring head."""
+    conv_op = tf.keras.layers.Conv2D
+    conv_kwargs = {
+        'filters': self._config_dict['num_filters'],
+        'kernel_size': 3,
+        'padding': 'same',
+    }
+    conv_kwargs.update({
+        'kernel_initializer': tf.keras.initializers.VarianceScaling(
+            scale=2, mode='fan_out', distribution='untruncated_normal'),
+        'bias_initializer': tf.zeros_initializer(),
+        'kernel_regularizer': self._config_dict['kernel_regularizer'],
+        'bias_regularizer': self._config_dict['bias_regularizer'],
+    })
+    bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
+             if self._config_dict['use_sync_bn']
+             else tf.keras.layers.BatchNormalization)
+    bn_kwargs = {
+        'axis': self._bn_axis,
+        'momentum': self._config_dict['norm_momentum'],
+        'epsilon': self._config_dict['norm_epsilon'],
+    }
+
+    self._convs = []
+    self._conv_norms = []
+    for i in range(self._config_dict['num_convs']):
+      conv_name = 'mask-scoring_{}'.format(i)
+      self._convs.append(conv_op(name=conv_name, **conv_kwargs))
+      bn_name = 'mask-scoring-bn_{}'.format(i)
+      self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))
+
+    self._fcs = []
+    self._fc_norms = []
+    for i in range(self._config_dict['num_fcs']):
+      fc_name = 'mask-scoring-fc_{}'.format(i)
+      self._fcs.append(
+          tf.keras.layers.Dense(
+              units=self._config_dict['fc_dims'],
+              kernel_initializer=tf.keras.initializers.VarianceScaling(
+                  scale=1 / 3.0, mode='fan_out', distribution='uniform'),
+              kernel_regularizer=self._config_dict['kernel_regularizer'],
+              bias_regularizer=self._config_dict['bias_regularizer'],
+              name=fc_name))
+      bn_name = 'mask-scoring-fc-bn_{}'.format(i)
+      self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))
+
+    self._classifier = tf.keras.layers.Dense(
+        units=self._config_dict['num_classes'],
+        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
+        bias_initializer=tf.zeros_initializer(),
+        kernel_regularizer=self._config_dict['kernel_regularizer'],
+        bias_regularizer=self._config_dict['bias_regularizer'],
+        name='iou-scores')
+
+    super(MaskScoring, self).build(input_shape)
+
+  def call(self, inputs: tf.Tensor, training: bool = None):
+    """Forward pass of the mask scoring head.
+
+    Args:
+      inputs: A `tf.Tensor` of the shape [batch_size, height, width,
+        num_classes], representing the segmentation logits.
+      training: a `bool` indicating whether it is in `training` mode.
+
+    Returns:
+      mask_scores: A `tf.Tensor` of predicted mask scores
+        [batch_size, num_classes].
+    """
+    x = tf.stop_gradient(inputs)
+    for conv, bn in zip(self._convs, self._conv_norms):
+      x = conv(x)
+      x = bn(x)
+      x = self._activation(x)
+
+    # Casts feat to float32 so the resize op can be run on TPU.
+    x = tf.cast(x, tf.float32)
+    x = tf.image.resize(x, size=self._config_dict['fc_input_size'],
+                        method=tf.image.ResizeMethod.BILINEAR)
+    # Casts it back to be compatible with the rest of the operations.
+    x = tf.cast(x, inputs.dtype)
+
+    _, h, w, filters = x.get_shape().as_list()
+    x = tf.reshape(x, [-1, h * w * filters])
+
+    for fc, bn in zip(self._fcs, self._fc_norms):
+      x = fc(x)
+      x = bn(x)
+      x = self._activation(x)
+
+    ious = self._classifier(x)
+    return ious
+
+  def get_config(self) -> Mapping[str, Any]:
+    return self._config_dict
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    return cls(**config)
+
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class SegmentationHead(tf.keras.layers.Layer):
   """Creates a segmentation head."""
@@ -225,6 +395,7 @@ class SegmentationHead(tf.keras.layers.Layer):
       segmentation prediction mask: A `tf.Tensor` of the segmentation mask
         scores predicted from input features.
     """
+
     backbone_output = inputs[0]
     decoder_output = inputs[1]
     if self._config_dict['feature_fusion'] == 'deeplabv3plus':
...
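
A detail worth noting in the MaskScoring head above: the logits are bilinearly resized to fc_input_size before being flattened, so the first dense layer always sees an h * w * num_filters input regardless of the logits' spatial size. A quick shape walkthrough (values are illustrative, assuming MaskScoring as defined above):

    import tensorflow as tf

    logits = tf.random.normal([2, 64, 64, 16])        # [B, H, W, num_classes]
    head = MaskScoring(num_classes=16, fc_input_size=[4, 4],
                       num_convs=2, num_filters=128, fc_dims=256)
    scores = head(logits)
    # convs: [2, 64, 64, 128] -> resize: [2, 4, 4, 128] -> flatten: [2, 2048]
    # -> two FCs: [2, 256] -> classifier: [2, 16]
    assert scores.shape.as_list() == [2, 16]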
@@ -72,5 +72,36 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
     new_head = segmentation_heads.SegmentationHead.from_config(config)
     self.assertAllEqual(head.get_config(), new_head.get_config())


+class MaskScoringHeadTest(parameterized.TestCase, tf.test.TestCase):
+
+  @parameterized.parameters(
+      (1, 1, 64, [4, 4]),
+      (2, 1, 64, [4, 4]),
+      (3, 1, 64, [4, 4]),
+      (1, 2, 32, [8, 8]),
+      (2, 2, 32, [8, 8]),
+      (3, 2, 32, [8, 8]),
+  )
+  def test_forward(self, num_convs, num_fcs,
+                   num_filters, fc_input_size):
+    features = np.random.rand(2, 64, 64, 16)
+    head = segmentation_heads.MaskScoring(
+        num_classes=2,
+        num_convs=num_convs,
+        num_fcs=num_fcs,
+        num_filters=num_filters,
+        fc_dims=128,
+        fc_input_size=fc_input_size)
+    scores = head(features)
+    self.assertAllEqual(scores.numpy().shape, [2, 2])
+
+  def test_serialize_deserialize(self):
+    head = segmentation_heads.MaskScoring(
+        num_classes=2, fc_input_size=[4, 4], fc_dims=128)
+    config = head.get_config()
+    new_head = segmentation_heads.MaskScoring.from_config(config)
+    self.assertAllEqual(head.get_config(), new_head.get_config())
+
+
 if __name__ == '__main__':
   tf.test.main()
@@ -13,7 +13,7 @@
 # limitations under the License.

 """Build segmentation models."""
-from typing import Any, Mapping, Union
+from typing import Any, Mapping, Union, Optional, Dict

 # Import libraries
 import tensorflow as tf
@@ -35,13 +35,16 @@ class SegmentationModel(tf.keras.Model):
   """

-  def __init__(self, backbone: tf.keras.Model, decoder: tf.keras.Model,
-               head: tf.keras.layers.Layer, **kwargs):
+  def __init__(self, backbone: tf.keras.Model, decoder: tf.keras.Model,
+               head: tf.keras.layers.Layer,
+               mask_scoring_head: Optional[tf.keras.layers.Layer] = None,
+               **kwargs):
     """Segmentation initialization function.

     Args:
       backbone: a backbone network.
       decoder: a decoder network. E.g. FPN.
       head: segmentation head.
+      mask_scoring_head: mask scoring head.
       **kwargs: keyword arguments to be passed.
     """
     super(SegmentationModel, self).__init__(**kwargs)
@@ -49,12 +52,15 @@ class SegmentationModel(tf.keras.Model):
         'backbone': backbone,
         'decoder': decoder,
         'head': head,
+        'mask_scoring_head': mask_scoring_head,
     }
     self.backbone = backbone
     self.decoder = decoder
     self.head = head
+    self.mask_scoring_head = mask_scoring_head

-  def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
+  def call(self, inputs: tf.Tensor, training: bool = None
+          ) -> Dict[str, tf.Tensor]:
     backbone_features = self.backbone(inputs)

     if self.decoder:
@@ -62,7 +68,12 @@ class SegmentationModel(tf.keras.Model):
     else:
       decoder_features = backbone_features

-    return self.head((backbone_features, decoder_features))
+    logits = self.head((backbone_features, decoder_features))
+    outputs = {'logits': logits}
+    if self.mask_scoring_head:
+      mask_scores = self.mask_scoring_head(logits)
+      outputs.update({'mask_scores': mask_scores})
+    return outputs

   @property
   def checkpoint_items(
@@ -71,6 +82,8 @@ class SegmentationModel(tf.keras.Model):
     items = dict(backbone=self.backbone, head=self.head)
     if self.decoder is not None:
       items.update(decoder=self.decoder)
+    if self.mask_scoring_head is not None:
+      items.update(mask_scoring_head=self.mask_scoring_head)
     return items

   def get_config(self) -> Mapping[str, Any]:
...
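
Since SegmentationModel.call now returns a dict instead of a bare tensor, callers index into it. A minimal sketch of the new calling convention (the input shape is a placeholder):

    outputs = model(tf.zeros([1, 512, 512, 3]), training=False)
    logits = outputs['logits']                 # always present
    mask_scores = outputs.get('mask_scores')   # None unless mask_scoring_head is set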
@@ -50,13 +50,14 @@ class SegmentationNetworkTest(parameterized.TestCase, tf.test.TestCase):
     model = segmentation_model.SegmentationModel(
         backbone=backbone,
         decoder=decoder,
-        head=head
+        head=head,
+        mask_scoring_head=None,
     )
-    logits = model(inputs)
+    outputs = model(inputs)
     self.assertAllEqual(
         [2, input_size // (2**level), input_size // (2**level), num_classes],
-        logits.numpy().shape)
+        outputs['logits'].numpy().shape)

   def test_serialize_deserialize(self):
     """Validate the network can be serialized and deserialized."""
...
@@ -77,7 +77,8 @@ class SegmentationModule(export_base.ExportModule):
             shape=self._input_image_size + [3], dtype=tf.float32),
         parallel_iterations=32))

-    masks = self.inference_step(images)
-    masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
-    return dict(predicted_masks=masks)
+    outputs = self.inference_step(images)
+    outputs['logits'] = tf.image.resize(
+        outputs['logits'], self._input_image_size, method='bilinear')
+    return outputs
@@ -103,10 +103,10 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
     else:
       processed_images = images
     expected_output = tf.image.resize(
-        module.model(processed_images, training=False), [112, 112],
+        module.model(processed_images, training=False)['logits'], [112, 112],
         method='bilinear')
     out = segmentation_fn(tf.constant(images))
-    self.assertAllClose(out['predicted_masks'].numpy(), expected_output.numpy())
+    self.assertAllClose(out['logits'].numpy(), expected_output.numpy())

 if __name__ == '__main__':
...
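
Note the exported signature change: consumers that previously read 'predicted_masks' now read 'logits' (already resized to the input image size). A hypothetical serving-side sketch; the export directory, signature key, input name, and dtype are all assumptions:

    # Hypothetical consumer of the exported SavedModel.
    imported = tf.saved_model.load(export_dir)  # export_dir is a placeholder path
    segmentation_fn = imported.signatures['serving_default']
    out = segmentation_fn(inputs=tf.zeros([1, 512, 512, 3], tf.uint8))
    masks = out['logits']  # previously keyed as 'predicted_masks'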
@@ -135,7 +135,15 @@ class SemanticSegmentationTask(base_task.Task):
         use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
         top_k_percent_pixels=loss_params.top_k_percent_pixels)

-    total_loss = segmentation_loss_fn(model_outputs, labels['masks'])
+    total_loss = segmentation_loss_fn(model_outputs['logits'], labels['masks'])
+
+    if 'mask_scores' in model_outputs:
+      mask_scoring_loss_fn = segmentation_losses.MaskScoringLoss(
+          loss_params.ignore_label)
+      total_loss += mask_scoring_loss_fn(
+          model_outputs['mask_scores'],
+          model_outputs['logits'],
+          labels['masks'])

     if aux_losses:
       total_loss += tf.add_n(aux_losses)
@@ -144,6 +152,28 @@ class SemanticSegmentationTask(base_task.Task):
     return total_loss
+  def process_metrics(self, metrics, labels, model_outputs, **kwargs):
+    """Process and update metrics.
+
+    Called when using custom training loop API.
+
+    Args:
+      metrics: a nested structure of metrics objects. The return of function
+        self.build_metrics.
+      labels: a tensor or a nested structure of tensors.
+      model_outputs: a tensor or a nested structure of tensors. For example,
+        output of the keras model built by self.build_model.
+      **kwargs: other args.
+    """
+    for metric in metrics:
+      if metric.name == 'mask_scores_mse':
+        actual_mask_scores = segmentation_losses.get_actual_mask_scores(
+            model_outputs['logits'], labels['masks'],
+            self.task_config.losses.ignore_label)
+        metric.update_state(actual_mask_scores, model_outputs['mask_scores'])
+      else:
+        metric.update_state(labels, model_outputs['logits'])
+
   def build_metrics(self, training: bool = True):
     """Gets streaming metrics for training/validation."""
     metrics = []
@@ -153,6 +183,9 @@ class SemanticSegmentationTask(base_task.Task):
               num_classes=self.task_config.model.num_classes,
               rescale_predictions=False,
               dtype=tf.float32))
+      if self.task_config.model.mask_scoring_head:
+        metrics.append(
+            tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
     else:
       self.iou_metric = segmentation_metrics.PerClassIoU(
           name='per_class_iou',
@@ -160,6 +193,11 @@ class SemanticSegmentationTask(base_task.Task):
           rescale_predictions=not self.task_config.validation_data
           .resize_eval_groundtruth,
           dtype=tf.float32)
+      if (self.task_config.validation_data.resize_eval_groundtruth and
+          self.task_config.model.mask_scoring_head):
+        # Mask scores metric can only be computed if labels are scaled to match
+        # predicted mask scores.
+        metrics.append(
+            tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))

       # Update state on CPU if TPUStrategy due to dynamic resizing.
       self._process_iou_metric_on_cpu = isinstance(
@@ -260,9 +298,9 @@ class SemanticSegmentationTask(base_task.Task):
     logs = {self.loss: loss}

     if self._process_iou_metric_on_cpu:
-      logs.update({self.iou_metric.name: (labels, outputs)})
+      logs.update({self.iou_metric.name: (labels, outputs['logits'])})
     else:
-      self.iou_metric.update_state(labels, outputs)
+      self.iou_metric.update_state(labels, outputs['logits'])

     if metrics:
       self.process_metrics(metrics, labels, outputs)
...