Commit 8bc64372 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 343888044
parent 50efd367
......@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Backbones configurations."""
from typing import Optional
from typing import Optional, List
# Import libraries
import dataclasses
......@@ -36,6 +36,9 @@ class DilatedResNet(hyperparams.Config):
"""DilatedResNet config."""
model_id: int = 50
output_stride: int = 16
multigrid: Optional[List[int]] = None
stem_type: str = 'v0'
last_stage_repeats: int = 1
@dataclasses.dataclass
......
......@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config):
dilation_rates: List[int] = dataclasses.field(default_factory=list)
dropout_rate: float = 0.0
num_filters: int = 256
pool_kernel_size: Optional[List[int]] = None # Use global average pooling.
@dataclasses.dataclass
......
# Top1 accuracy 80.36%
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -10,6 +11,9 @@ task:
dilated_resnet:
model_id: 101
output_stride: 16
stem_type: 'v1'
multigrid: [1, 2, 4]
last_stage_repeats: 1
norm_activation:
activation: 'swish'
losses:
......
# Dilated ResNet-101 Pascal segmentation. 80.89 mean IOU.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 101
output_stride: 8
norm_activation:
activation: 'swish'
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 50
output_stride: 8
norm_activation:
activation: 'swish'
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
# Dilated ResNet-101 Pascal segmentation. 80.83 mean IOU with output stride of 16.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 101
output_stride: 16
head:
feature_fusion: 'deeplabv3plus'
low_level: 2
low_level_num_filters: 48
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 50
output_stride: 16
head:
feature_fusion: 'deeplabv3plus'
low_level: 2
low_level_num_filters: 48
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
......@@ -32,6 +32,8 @@ from official.vision.beta.configs import decoders
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
output_size: List[int] = dataclasses.field(default_factory=list)
train_on_crops: bool = False
input_path: str = ''
global_batch_size: int = 0
is_training: bool = True
......@@ -42,6 +44,7 @@ class DataConfig(cfg.DataConfig):
groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
aug_scale_min: float = 1.0
aug_scale_max: float = 1.0
aug_rand_hflip: bool = True
drop_remainder: bool = True
......@@ -73,11 +76,12 @@ class SemanticSegmentationModel(hyperparams.Config):
@dataclasses.dataclass
class Losses(hyperparams.Config):
label_smoothing: float = 0.1
label_smoothing: float = 0.0
ignore_label: int = 255
class_weights: List[float] = dataclasses.field(default_factory=list)
l2_weight_decay: float = 0.0
use_groundtruth_dimension: bool = True
top_k_percent_pixels: float = 1.0
@dataclasses.dataclass
......@@ -115,18 +119,20 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
train_batch_size = 16
eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
output_stride = 8
output_stride = 16
aspp_dilation_rates = [12, 24, 36] # [6, 12, 18] if output_stride = 16
multigrid = [1, 2, 4]
stem_type = 'v1'
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
task=SemanticSegmentationTask(
model=SemanticSegmentationModel(
num_classes=21,
# TODO(arashwan): test changing size to 513 to match deeplab.
input_size=[512, 512, 3],
input_size=[None, None, 3],
backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50, output_stride=output_stride)),
model_id=101, output_stride=output_stride,
multigrid=multigrid, stem_type=stem_type)),
decoder=decoders.Decoder(
type='aspp', aspp=decoders.ASPP(
level=level, dilation_rates=aspp_dilation_rates)),
......@@ -139,19 +145,22 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
# TODO(arashwan): test changing size to 513 to match deeplab.
output_size=[512, 512],
is_training=True,
global_batch_size=train_batch_size,
aug_scale_min=0.5,
aug_scale_max=2.0),
validation_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
output_size=[512, 512],
is_training=False,
global_batch_size=eval_batch_size,
resize_eval_groundtruth=False,
groundtruth_padded_size=[512, 512],
drop_remainder=False),
# resnet50
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400',
# resnet101
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
init_checkpoint_modules='backbone'),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
......@@ -199,16 +208,19 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
output_stride = 16
aspp_dilation_rates = [6, 12, 18] # [12, 24, 36] if output_stride = 8
aspp_dilation_rates = [6, 12, 18]
multigrid = [1, 2, 4]
stem_type = 'v1'
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
task=SemanticSegmentationTask(
model=SemanticSegmentationModel(
num_classes=21,
input_size=[512, 512, 3],
input_size=[None, None, 3],
backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50, output_stride=output_stride)),
model_id=101, output_stride=output_stride,
stem_type=stem_type, multigrid=multigrid)),
decoder=decoders.Decoder(
type='aspp',
aspp=decoders.ASPP(
......@@ -227,19 +239,21 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
output_size=[512, 512],
is_training=True,
global_batch_size=train_batch_size,
aug_scale_min=0.5,
aug_scale_max=2.0),
validation_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
output_size=[512, 512],
is_training=False,
global_batch_size=eval_batch_size,
resize_eval_groundtruth=False,
groundtruth_padded_size=[512, 512],
drop_remainder=False),
# resnet50
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400',
# resnet101
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
init_checkpoint_modules='backbone'),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
......
......@@ -38,10 +38,12 @@ class Decoder(decoder.Decoder):
class Parser(parser.Parser):
"""Parser to parse an image and its annotations into a dictionary of tensors."""
"""Parser to parse an image and its annotations into a dictionary of tensors.
"""
def __init__(self,
output_size,
train_on_crops=False,
resize_eval_groundtruth=True,
groundtruth_padded_size=None,
ignore_label=255,
......@@ -54,6 +56,9 @@ class Parser(parser.Parser):
Args:
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level.
train_on_crops: `bool`, if True, a training crop of size output_size
is returned. This is useful for cropping original images during training
while evaluating on original image sizes.
resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
resized to output_size.
groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
......@@ -70,6 +75,7 @@ class Parser(parser.Parser):
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
"""
self._output_size = output_size
self._train_on_crops = train_on_crops
self._resize_eval_groundtruth = resize_eval_groundtruth
if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
......@@ -104,9 +110,22 @@ class Parser(parser.Parser):
"""Parses data for training and evaluation."""
image, label = self._prepare_image_and_label(data)
if self._train_on_crops:
if data['image/height'] < self._output_size[0] or data[
'image/width'] < self._output_size[1]:
raise ValueError(
'Image size has to be larger than crop size (output_size)')
label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
image_mask = tf.concat([image, label], axis=2)
image_mask_crop = tf.image.random_crop(image_mask,
self._output_size + [4])
image = image_mask_crop[:, :, :-1]
label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._output_size)
# Flips image randomly during training.
if self._aug_rand_hflip:
image, label = preprocess_ops.random_horizontal_flip(image, masks=label)
image, _, label = preprocess_ops.random_horizontal_flip(
image, masks=label)
# Resizes and crops image.
image, image_info = preprocess_ops.resize_and_crop_image(
......
......@@ -23,8 +23,9 @@ EPSILON = 1e-5
class SegmentationLoss:
"""Semantic segmentation loss."""
def __init__(self, label_smoothing, class_weights,
ignore_label, use_groundtruth_dimension):
def __init__(self, label_smoothing, class_weights, ignore_label,
use_groundtruth_dimension, top_k_percent_pixels=1.0):
self._top_k_percent_pixels = top_k_percent_pixels
self._class_weights = class_weights
self._ignore_label = ignore_label
self._use_groundtruth_dimension = use_groundtruth_dimension
......@@ -71,5 +72,18 @@ class SegmentationLoss:
tf.constant(class_weights, tf.float32))
valid_mask *= weight_mask
cross_entropy_loss *= tf.cast(valid_mask, tf.float32)
if self._top_k_percent_pixels >= 1.0:
loss = tf.reduce_sum(cross_entropy_loss) / normalizer
else:
cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1])
top_k_pixels = tf.cast(
self._top_k_percent_pixels *
tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32)
top_k_losses, _ = tf.math.top_k(
cross_entropy_loss, k=top_k_pixels, sorted=True)
normalizer = tf.reduce_sum(
tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32) + EPSILON)
loss = tf.reduce_sum(top_k_losses) / normalizer
return loss
......@@ -56,6 +56,9 @@ class DilatedResNet(tf.keras.Model):
model_id,
output_stride,
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
stem_type='v0',
multigrid=None,
last_stage_repeats=1,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
......@@ -70,6 +73,11 @@ class DilatedResNet(tf.keras.Model):
model_id: `int` depth of ResNet backbone model.
output_stride: `int` output stride, ratio of input to output resolution.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
stem_type: `v0` or `v1`, where `v1` (deeplab-style) replaces the 7x7 conv
        by three 3x3 convs.
multigrid: `Tuple` of the same length as the number of blocks in the last
resnet stage.
last_stage_repeats: `int`, how many times last stage is repeated.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization momentum for the moving average.
......@@ -96,6 +104,7 @@ class DilatedResNet(tf.keras.Model):
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._stem_type = stem_type
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1
......@@ -105,8 +114,13 @@ class DilatedResNet(tf.keras.Model):
# Build ResNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
if stem_type == 'v0':
x = layers.Conv2D(
filters=64, kernel_size=7, strides=2, use_bias=False, padding='same',
filters=64,
kernel_size=7,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
......@@ -115,6 +129,52 @@ class DilatedResNet(tf.keras.Model):
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
elif stem_type == 'v1':
x = layers.Conv2D(
filters=64,
kernel_size=3,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
filters=64,
kernel_size=3,
strides=1,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
filters=128,
kernel_size=3,
strides=1,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
else:
raise ValueError('Stem type {} not supported.'.format(stem_type))
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
normal_resnet_stage = int(np.math.log2(self._output_stride)) - 2
......@@ -137,7 +197,7 @@ class DilatedResNet(tf.keras.Model):
endpoints[str(i + 2)] = x
dilation_rate = 2
for i in range(normal_resnet_stage + 1, 7):
for i in range(normal_resnet_stage + 1, 3 + last_stage_repeats):
spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[model_id][-1]
if spec[0] == 'bottleneck':
block_fn = nn_blocks.BottleneckBlock
......@@ -150,6 +210,7 @@ class DilatedResNet(tf.keras.Model):
dilation_rate=dilation_rate,
block_fn=block_fn,
block_repeats=spec[2],
multigrid=multigrid if i >= 3 else None,
name='block_group_l{}'.format(i + 2))
dilation_rate *= 2
......@@ -167,9 +228,12 @@ class DilatedResNet(tf.keras.Model):
dilation_rate,
block_fn,
block_repeats=1,
multigrid=None,
name='block_group'):
"""Creates one group of blocks for the ResNet model.
Deeplab applies strides at the last block.
Args:
inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer.
......@@ -178,15 +242,24 @@ class DilatedResNet(tf.keras.Model):
dilation_rate: `int`, diluted convolution rates.
block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
block_repeats: `int` number of blocks contained in the layer.
multigrid: List of ints or None, if specified, dilation rates for each
block is scaled up by its corresponding factor in the multigrid.
name: `str` name for the block.
Returns:
The output `Tensor` of the block layer.
"""
if multigrid is not None and len(multigrid) != block_repeats:
raise ValueError('multigrid has to match number of block_repeats')
if multigrid is None:
multigrid = [1] * block_repeats
# TODO(arashwan): move striding to the end of the block.
x = block_fn(
filters=filters,
strides=strides,
dilation_rate=dilation_rate,
dilation_rate=dilation_rate * multigrid[0],
use_projection=True,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
......@@ -196,12 +269,11 @@ class DilatedResNet(tf.keras.Model):
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
for _ in range(1, block_repeats):
for i in range(1, block_repeats):
x = block_fn(
filters=filters,
strides=1,
dilation_rate=dilation_rate,
dilation_rate=dilation_rate * multigrid[i],
use_projection=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
......@@ -254,6 +326,9 @@ def build_dilated_resnet(
model_id=backbone_cfg.model_id,
output_stride=backbone_cfg.output_stride,
input_specs=input_specs,
multigrid=backbone_cfg.multigrid,
last_stage_repeats=backbone_cfg.last_stage_repeats,
stem_type=backbone_cfg.stem_type,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
......
......@@ -28,6 +28,7 @@ class ASPP(tf.keras.layers.Layer):
level,
dilation_rates,
num_filters=256,
pool_kernel_size=None,
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
......@@ -43,6 +44,9 @@ class ASPP(tf.keras.layers.Layer):
level: `int` level to apply ASPP.
dilation_rates: `list` of dilation rates.
num_filters: `int` number of output filters in ASPP.
pool_kernel_size: `list` of [height, width] of pooling kernel size or
None. Pooling size is with respect to original image size, it will be
scaled down by 2**level. If None, global average pooling is used.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
......@@ -60,6 +64,7 @@ class ASPP(tf.keras.layers.Layer):
'level': level,
'dilation_rates': dilation_rates,
'num_filters': num_filters,
'pool_kernel_size': pool_kernel_size,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
......@@ -71,9 +76,16 @@ class ASPP(tf.keras.layers.Layer):
}
def build(self, input_shape):
pool_kernel_size = None
if self._config_dict['pool_kernel_size']:
pool_kernel_size = [
int(p_size // 2**self._config_dict['level'])
for p_size in self._config_dict['pool_kernel_size']
]
self.aspp = keras_cv.layers.SpatialPyramidPooling(
output_channels=self._config_dict['num_filters'],
dilation_rates=self._config_dict['dilation_rates'],
pool_kernel_size=pool_kernel_size,
use_sync_bn=self._config_dict['use_sync_bn'],
batchnorm_momentum=self._config_dict['norm_momentum'],
batchnorm_epsilon=self._config_dict['norm_epsilon'],
......
......@@ -61,6 +61,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
level=3,
dilation_rates=[6, 12],
num_filters=256,
pool_kernel_size=None,
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
......
......@@ -70,6 +70,7 @@ def build_decoder(input_specs,
level=decoder_cfg.level,
dilation_rates=decoder_cfg.dilation_rates,
num_filters=decoder_cfg.num_filters,
pool_kernel_size=decoder_cfg.pool_kernel_size,
dropout_rate=decoder_cfg.dropout_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
......
......@@ -102,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer):
conv_kwargs = {
'kernel_size': 3,
'padding': 'same',
'bias_initializer': tf.zeros_initializer(),
'use_bias': False,
'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
}
......@@ -120,7 +120,7 @@ class SegmentationHead(tf.keras.layers.Layer):
self._dlv3p_conv = conv_op(
kernel_size=1,
padding='same',
bias_initializer=tf.zeros_initializer(),
use_bias=False,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
name='segmentation_head_deeplabv3p_fusion_conv',
......@@ -145,7 +145,12 @@ class SegmentationHead(tf.keras.layers.Layer):
self._classifier = conv_op(
name='segmentation_output',
filters=self._config_dict['num_classes'],
**conv_kwargs)
kernel_size=1,
padding='same',
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
super(SegmentationHead, self).build(input_shape)
......
......@@ -81,17 +81,18 @@ class SemanticSegmentationTask(base_task.Task):
def build_inputs(self, params, input_context=None):
"""Builds classification input."""
input_size = self.task_config.model.input_size
ignore_label = self.task_config.losses.ignore_label
decoder = segmentation_input.Decoder()
parser = segmentation_input.Parser(
output_size=input_size[:2],
output_size=params.output_size,
train_on_crops=params.train_on_crops,
ignore_label=ignore_label,
resize_eval_groundtruth=params.resize_eval_groundtruth,
groundtruth_padded_size=params.groundtruth_padded_size,
aug_scale_min=params.aug_scale_min,
aug_scale_max=params.aug_scale_max,
aug_rand_hflip=params.aug_rand_hflip,
dtype=params.dtype)
reader = input_reader.InputReader(
......@@ -120,7 +121,8 @@ class SemanticSegmentationTask(base_task.Task):
loss_params.label_smoothing,
loss_params.class_weights,
loss_params.ignore_label,
use_groundtruth_dimension=loss_params.use_groundtruth_dimension)
use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
top_k_percent_pixels=loss_params.top_k_percent_pixels)
total_loss = segmentation_loss_fn(model_outputs, labels['masks'])
......@@ -133,19 +135,18 @@ class SemanticSegmentationTask(base_task.Task):
"""Gets streaming metrics for training/validation."""
metrics = []
if training:
# TODO(arashwan): make MeanIoU tpu friendly.
if not isinstance(tf.distribute.get_strategy(),
tf.distribute.TPUStrategy):
metrics.append(segmentation_metrics.MeanIoU(
name='mean_iou',
num_classes=self.task_config.model.num_classes,
rescale_predictions=False))
rescale_predictions=False,
dtype=tf.float32))
else:
self.miou_metric = segmentation_metrics.MeanIoU(
name='val_mean_iou',
num_classes=self.task_config.model.num_classes,
rescale_predictions=not self.task_config.validation_data
.resize_eval_groundtruth)
.resize_eval_groundtruth,
dtype=tf.float32)
return metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment