Commit f8ec1e39 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 343888044
parent 1b5a4c9e
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Backbones configurations.""" """Backbones configurations."""
from typing import Optional from typing import Optional, List
# Import libraries # Import libraries
import dataclasses import dataclasses
...@@ -36,6 +36,9 @@ class DilatedResNet(hyperparams.Config): ...@@ -36,6 +36,9 @@ class DilatedResNet(hyperparams.Config):
"""DilatedResNet config.""" """DilatedResNet config."""
model_id: int = 50 model_id: int = 50
output_stride: int = 16 output_stride: int = 16
multigrid: Optional[List[int]] = None
stem_type: str = 'v0'
last_stage_repeats: int = 1
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config): ...@@ -50,6 +50,7 @@ class ASPP(hyperparams.Config):
dilation_rates: List[int] = dataclasses.field(default_factory=list) dilation_rates: List[int] = dataclasses.field(default_factory=list)
dropout_rate: float = 0.0 dropout_rate: float = 0.0
num_filters: int = 256 num_filters: int = 256
pool_kernel_size: Optional[List[int]] = None # Use global average pooling.
@dataclasses.dataclass @dataclasses.dataclass
......
# Top1 accuracy 80.36%
runtime: runtime:
distribution_strategy: 'tpu' distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16' mixed_precision_dtype: 'bfloat16'
...@@ -10,6 +11,9 @@ task: ...@@ -10,6 +11,9 @@ task:
dilated_resnet: dilated_resnet:
model_id: 101 model_id: 101
output_stride: 16 output_stride: 16
stem_type: 'v1'
multigrid: [1, 2, 4]
last_stage_repeats: 1
norm_activation: norm_activation:
activation: 'swish' activation: 'swish'
losses: losses:
......
# Dilated ResNet-101 Pascal segmentation. 80.89 mean IOU.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 101
output_stride: 8
norm_activation:
activation: 'swish'
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 50
output_stride: 8
norm_activation:
activation: 'swish'
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
# Dilated ResNet-101 Pascal segmentation. 80.83 mean IOU with output stride of 16.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 101
output_stride: 16
head:
feature_fusion: 'deeplabv3plus'
low_level: 2
low_level_num_filters: 48
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32'
task:
model:
backbone:
type: 'dilated_resnet'
dilated_resnet:
model_id: 50
output_stride: 16
head:
feature_fusion: 'deeplabv3plus'
low_level: 2
low_level_num_filters: 48
init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
init_checkpoint_modules: 'backbone'
...@@ -32,6 +32,8 @@ from official.vision.beta.configs import decoders ...@@ -32,6 +32,8 @@ from official.vision.beta.configs import decoders
@dataclasses.dataclass @dataclasses.dataclass
class DataConfig(cfg.DataConfig): class DataConfig(cfg.DataConfig):
"""Input config for training.""" """Input config for training."""
output_size: List[int] = dataclasses.field(default_factory=list)
train_on_crops: bool = False
input_path: str = '' input_path: str = ''
global_batch_size: int = 0 global_batch_size: int = 0
is_training: bool = True is_training: bool = True
...@@ -42,6 +44,7 @@ class DataConfig(cfg.DataConfig): ...@@ -42,6 +44,7 @@ class DataConfig(cfg.DataConfig):
groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list) groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
aug_scale_min: float = 1.0 aug_scale_min: float = 1.0
aug_scale_max: float = 1.0 aug_scale_max: float = 1.0
aug_rand_hflip: bool = True
drop_remainder: bool = True drop_remainder: bool = True
...@@ -73,11 +76,12 @@ class SemanticSegmentationModel(hyperparams.Config): ...@@ -73,11 +76,12 @@ class SemanticSegmentationModel(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Losses(hyperparams.Config): class Losses(hyperparams.Config):
label_smoothing: float = 0.1 label_smoothing: float = 0.0
ignore_label: int = 255 ignore_label: int = 255
class_weights: List[float] = dataclasses.field(default_factory=list) class_weights: List[float] = dataclasses.field(default_factory=list)
l2_weight_decay: float = 0.0 l2_weight_decay: float = 0.0
use_groundtruth_dimension: bool = True use_groundtruth_dimension: bool = True
top_k_percent_pixels: float = 1.0
@dataclasses.dataclass @dataclasses.dataclass
...@@ -115,18 +119,20 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig: ...@@ -115,18 +119,20 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
train_batch_size = 16 train_batch_size = 16
eval_batch_size = 8 eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
output_stride = 8 output_stride = 16
aspp_dilation_rates = [12, 24, 36] # [6, 12, 18] if output_stride = 16 aspp_dilation_rates = [12, 24, 36] # [6, 12, 18] if output_stride = 16
multigrid = [1, 2, 4]
stem_type = 'v1'
level = int(np.math.log2(output_stride)) level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
task=SemanticSegmentationTask( task=SemanticSegmentationTask(
model=SemanticSegmentationModel( model=SemanticSegmentationModel(
num_classes=21, num_classes=21,
# TODO(arashwan): test changing size to 513 to match deeplab. input_size=[None, None, 3],
input_size=[512, 512, 3],
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet( type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50, output_stride=output_stride)), model_id=101, output_stride=output_stride,
multigrid=multigrid, stem_type=stem_type)),
decoder=decoders.Decoder( decoder=decoders.Decoder(
type='aspp', aspp=decoders.ASPP( type='aspp', aspp=decoders.ASPP(
level=level, dilation_rates=aspp_dilation_rates)), level=level, dilation_rates=aspp_dilation_rates)),
...@@ -139,19 +145,22 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig: ...@@ -139,19 +145,22 @@ def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
losses=Losses(l2_weight_decay=1e-4), losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig( train_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'), input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
# TODO(arashwan): test changing size to 513 to match deeplab.
output_size=[512, 512],
is_training=True, is_training=True,
global_batch_size=train_batch_size, global_batch_size=train_batch_size,
aug_scale_min=0.5, aug_scale_min=0.5,
aug_scale_max=2.0), aug_scale_max=2.0),
validation_data=DataConfig( validation_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'), input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
output_size=[512, 512],
is_training=False, is_training=False,
global_batch_size=eval_batch_size, global_batch_size=eval_batch_size,
resize_eval_groundtruth=False, resize_eval_groundtruth=False,
groundtruth_padded_size=[512, 512], groundtruth_padded_size=[512, 512],
drop_remainder=False), drop_remainder=False),
# resnet50 # resnet101
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400', init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
init_checkpoint_modules='backbone'), init_checkpoint_modules='backbone'),
trainer=cfg.TrainerConfig( trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch, steps_per_loop=steps_per_epoch,
...@@ -199,16 +208,19 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig: ...@@ -199,16 +208,19 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
eval_batch_size = 8 eval_batch_size = 8
steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
output_stride = 16 output_stride = 16
aspp_dilation_rates = [6, 12, 18] # [12, 24, 36] if output_stride = 8 aspp_dilation_rates = [6, 12, 18]
multigrid = [1, 2, 4]
stem_type = 'v1'
level = int(np.math.log2(output_stride)) level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
task=SemanticSegmentationTask( task=SemanticSegmentationTask(
model=SemanticSegmentationModel( model=SemanticSegmentationModel(
num_classes=21, num_classes=21,
input_size=[512, 512, 3], input_size=[None, None, 3],
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet( type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50, output_stride=output_stride)), model_id=101, output_stride=output_stride,
stem_type=stem_type, multigrid=multigrid)),
decoder=decoders.Decoder( decoder=decoders.Decoder(
type='aspp', type='aspp',
aspp=decoders.ASPP( aspp=decoders.ASPP(
...@@ -227,19 +239,21 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig: ...@@ -227,19 +239,21 @@ def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
losses=Losses(l2_weight_decay=1e-4), losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig( train_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'), input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
output_size=[512, 512],
is_training=True, is_training=True,
global_batch_size=train_batch_size, global_batch_size=train_batch_size,
aug_scale_min=0.5, aug_scale_min=0.5,
aug_scale_max=2.0), aug_scale_max=2.0),
validation_data=DataConfig( validation_data=DataConfig(
input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'), input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
output_size=[512, 512],
is_training=False, is_training=False,
global_batch_size=eval_batch_size, global_batch_size=eval_batch_size,
resize_eval_groundtruth=False, resize_eval_groundtruth=False,
groundtruth_padded_size=[512, 512], groundtruth_padded_size=[512, 512],
drop_remainder=False), drop_remainder=False),
# resnet50 # resnet101
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400', init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
init_checkpoint_modules='backbone'), init_checkpoint_modules='backbone'),
trainer=cfg.TrainerConfig( trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch, steps_per_loop=steps_per_epoch,
......
...@@ -38,10 +38,12 @@ class Decoder(decoder.Decoder): ...@@ -38,10 +38,12 @@ class Decoder(decoder.Decoder):
class Parser(parser.Parser): class Parser(parser.Parser):
"""Parser to parse an image and its annotations into a dictionary of tensors.""" """Parser to parse an image and its annotations into a dictionary of tensors.
"""
def __init__(self, def __init__(self,
output_size, output_size,
train_on_crops=False,
resize_eval_groundtruth=True, resize_eval_groundtruth=True,
groundtruth_padded_size=None, groundtruth_padded_size=None,
ignore_label=255, ignore_label=255,
...@@ -54,6 +56,9 @@ class Parser(parser.Parser): ...@@ -54,6 +56,9 @@ class Parser(parser.Parser):
Args: Args:
output_size: `Tensor` or `list` for [height, width] of output image. The output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level. output_size should be divided by the largest feature stride 2^max_level.
train_on_crops: `bool`, if True, a training crop of size output_size
is returned. This is useful for cropping original images during training
while evaluating on original image sizes.
resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
resized to output_size. resized to output_size.
groundtruth_padded_size: `Tensor` or `list` for [height, width]. When groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
...@@ -70,6 +75,7 @@ class Parser(parser.Parser): ...@@ -70,6 +75,7 @@ class Parser(parser.Parser):
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}. dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
""" """
self._output_size = output_size self._output_size = output_size
self._train_on_crops = train_on_crops
self._resize_eval_groundtruth = resize_eval_groundtruth self._resize_eval_groundtruth = resize_eval_groundtruth
if (not resize_eval_groundtruth) and (groundtruth_padded_size is None): if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
raise ValueError('groundtruth_padded_size ([height, width]) needs to be' raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
...@@ -104,9 +110,22 @@ class Parser(parser.Parser): ...@@ -104,9 +110,22 @@ class Parser(parser.Parser):
"""Parses data for training and evaluation.""" """Parses data for training and evaluation."""
image, label = self._prepare_image_and_label(data) image, label = self._prepare_image_and_label(data)
if self._train_on_crops:
if data['image/height'] < self._output_size[0] or data[
'image/width'] < self._output_size[1]:
raise ValueError(
'Image size has to be larger than crop size (output_size)')
label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
image_mask = tf.concat([image, label], axis=2)
image_mask_crop = tf.image.random_crop(image_mask,
self._output_size + [4])
image = image_mask_crop[:, :, :-1]
label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._output_size)
# Flips image randomly during training. # Flips image randomly during training.
if self._aug_rand_hflip: if self._aug_rand_hflip:
image, label = preprocess_ops.random_horizontal_flip(image, masks=label) image, _, label = preprocess_ops.random_horizontal_flip(
image, masks=label)
# Resizes and crops image. # Resizes and crops image.
image, image_info = preprocess_ops.resize_and_crop_image( image, image_info = preprocess_ops.resize_and_crop_image(
......
...@@ -23,8 +23,9 @@ EPSILON = 1e-5 ...@@ -23,8 +23,9 @@ EPSILON = 1e-5
class SegmentationLoss: class SegmentationLoss:
"""Semantic segmentation loss.""" """Semantic segmentation loss."""
def __init__(self, label_smoothing, class_weights, def __init__(self, label_smoothing, class_weights, ignore_label,
ignore_label, use_groundtruth_dimension): use_groundtruth_dimension, top_k_percent_pixels=1.0):
self._top_k_percent_pixels = top_k_percent_pixels
self._class_weights = class_weights self._class_weights = class_weights
self._ignore_label = ignore_label self._ignore_label = ignore_label
self._use_groundtruth_dimension = use_groundtruth_dimension self._use_groundtruth_dimension = use_groundtruth_dimension
...@@ -71,5 +72,18 @@ class SegmentationLoss: ...@@ -71,5 +72,18 @@ class SegmentationLoss:
tf.constant(class_weights, tf.float32)) tf.constant(class_weights, tf.float32))
valid_mask *= weight_mask valid_mask *= weight_mask
cross_entropy_loss *= tf.cast(valid_mask, tf.float32) cross_entropy_loss *= tf.cast(valid_mask, tf.float32)
loss = tf.reduce_sum(cross_entropy_loss) / normalizer
if self._top_k_percent_pixels >= 1.0:
loss = tf.reduce_sum(cross_entropy_loss) / normalizer
else:
cross_entropy_loss = tf.reshape(cross_entropy_loss, shape=[-1])
top_k_pixels = tf.cast(
self._top_k_percent_pixels *
tf.cast(tf.size(cross_entropy_loss), tf.float32), tf.int32)
top_k_losses, _ = tf.math.top_k(
cross_entropy_loss, k=top_k_pixels, sorted=True)
normalizer = tf.reduce_sum(
tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32) + EPSILON)
loss = tf.reduce_sum(top_k_losses) / normalizer
return loss return loss
...@@ -56,6 +56,9 @@ class DilatedResNet(tf.keras.Model): ...@@ -56,6 +56,9 @@ class DilatedResNet(tf.keras.Model):
model_id, model_id,
output_stride, output_stride,
input_specs=layers.InputSpec(shape=[None, None, None, 3]), input_specs=layers.InputSpec(shape=[None, None, None, 3]),
stem_type='v0',
multigrid=None,
last_stage_repeats=1,
activation='relu', activation='relu',
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
...@@ -70,6 +73,11 @@ class DilatedResNet(tf.keras.Model): ...@@ -70,6 +73,11 @@ class DilatedResNet(tf.keras.Model):
model_id: `int` depth of ResNet backbone model. model_id: `int` depth of ResNet backbone model.
output_stride: `int` output stride, ratio of input to output resolution. output_stride: `int` output stride, ratio of input to output resolution.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor. input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
stem_type: `v0` or `v1`; `v1` replaces the 7x7 conv with three 3x3
convs.
multigrid: `Tuple` of the same length as the number of blocks in the last
resnet stage.
last_stage_repeats: `int`, how many times last stage is repeated.
activation: `str` name of the activation function. activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization. use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization momentum for the moving average. norm_momentum: `float` normalization momentum for the moving average.
...@@ -96,6 +104,7 @@ class DilatedResNet(tf.keras.Model): ...@@ -96,6 +104,7 @@ class DilatedResNet(tf.keras.Model):
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
self._stem_type = stem_type
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1 bn_axis = -1
...@@ -105,16 +114,67 @@ class DilatedResNet(tf.keras.Model): ...@@ -105,16 +114,67 @@ class DilatedResNet(tf.keras.Model):
# Build ResNet. # Build ResNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
x = layers.Conv2D( if stem_type == 'v0':
filters=64, kernel_size=7, strides=2, use_bias=False, padding='same', x = layers.Conv2D(
kernel_initializer=self._kernel_initializer, filters=64,
kernel_regularizer=self._kernel_regularizer, kernel_size=7,
bias_regularizer=self._bias_regularizer)( strides=2,
inputs) use_bias=False,
x = self._norm( padding='same',
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( kernel_initializer=self._kernel_initializer,
x) kernel_regularizer=self._kernel_regularizer,
x = tf_utils.get_activation(activation)(x) bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
elif stem_type == 'v1':
x = layers.Conv2D(
filters=64,
kernel_size=3,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
filters=64,
kernel_size=3,
strides=1,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
x = layers.Conv2D(
filters=128,
kernel_size=3,
strides=1,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
else:
raise ValueError('Stem type {} not supported.'.format(stem_type))
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x) x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
normal_resnet_stage = int(np.math.log2(self._output_stride)) - 2 normal_resnet_stage = int(np.math.log2(self._output_stride)) - 2
...@@ -137,7 +197,7 @@ class DilatedResNet(tf.keras.Model): ...@@ -137,7 +197,7 @@ class DilatedResNet(tf.keras.Model):
endpoints[str(i + 2)] = x endpoints[str(i + 2)] = x
dilation_rate = 2 dilation_rate = 2
for i in range(normal_resnet_stage + 1, 7): for i in range(normal_resnet_stage + 1, 3 + last_stage_repeats):
spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[model_id][-1] spec = RESNET_SPECS[model_id][i] if i < 3 else RESNET_SPECS[model_id][-1]
if spec[0] == 'bottleneck': if spec[0] == 'bottleneck':
block_fn = nn_blocks.BottleneckBlock block_fn = nn_blocks.BottleneckBlock
...@@ -150,6 +210,7 @@ class DilatedResNet(tf.keras.Model): ...@@ -150,6 +210,7 @@ class DilatedResNet(tf.keras.Model):
dilation_rate=dilation_rate, dilation_rate=dilation_rate,
block_fn=block_fn, block_fn=block_fn,
block_repeats=spec[2], block_repeats=spec[2],
multigrid=multigrid if i >= 3 else None,
name='block_group_l{}'.format(i + 2)) name='block_group_l{}'.format(i + 2))
dilation_rate *= 2 dilation_rate *= 2
...@@ -167,9 +228,12 @@ class DilatedResNet(tf.keras.Model): ...@@ -167,9 +228,12 @@ class DilatedResNet(tf.keras.Model):
dilation_rate, dilation_rate,
block_fn, block_fn,
block_repeats=1, block_repeats=1,
multigrid=None,
name='block_group'): name='block_group'):
"""Creates one group of blocks for the ResNet model. """Creates one group of blocks for the ResNet model.
Deeplab applies strides at the last block.
Args: Args:
inputs: `Tensor` of size `[batch, channels, height, width]`. inputs: `Tensor` of size `[batch, channels, height, width]`.
filters: `int` number of filters for the first convolution of the layer. filters: `int` number of filters for the first convolution of the layer.
...@@ -178,15 +242,24 @@ class DilatedResNet(tf.keras.Model): ...@@ -178,15 +242,24 @@ class DilatedResNet(tf.keras.Model):
dilation_rate: `int`, diluted convolution rates. dilation_rate: `int`, diluted convolution rates.
block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`. block_fn: Either `nn_blocks.ResidualBlock` or `nn_blocks.BottleneckBlock`.
block_repeats: `int` number of blocks contained in the layer. block_repeats: `int` number of blocks contained in the layer.
multigrid: List of ints or None, if specified, dilation rates for each
block is scaled up by its corresponding factor in the multigrid.
name: `str` name for the block. name: `str` name for the block.
Returns: Returns:
The output `Tensor` of the block layer. The output `Tensor` of the block layer.
""" """
if multigrid is not None and len(multigrid) != block_repeats:
raise ValueError('multigrid has to match number of block_repeats')
if multigrid is None:
multigrid = [1] * block_repeats
# TODO(arashwan): move striding at the end of the block.
x = block_fn( x = block_fn(
filters=filters, filters=filters,
strides=strides, strides=strides,
dilation_rate=dilation_rate, dilation_rate=dilation_rate * multigrid[0],
use_projection=True, use_projection=True,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -196,12 +269,11 @@ class DilatedResNet(tf.keras.Model): ...@@ -196,12 +269,11 @@ class DilatedResNet(tf.keras.Model):
norm_momentum=self._norm_momentum, norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)( norm_epsilon=self._norm_epsilon)(
inputs) inputs)
for i in range(1, block_repeats):
for _ in range(1, block_repeats):
x = block_fn( x = block_fn(
filters=filters, filters=filters,
strides=1, strides=1,
dilation_rate=dilation_rate, dilation_rate=dilation_rate * multigrid[i],
use_projection=False, use_projection=False,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
...@@ -254,6 +326,9 @@ def build_dilated_resnet( ...@@ -254,6 +326,9 @@ def build_dilated_resnet(
model_id=backbone_cfg.model_id, model_id=backbone_cfg.model_id,
output_stride=backbone_cfg.output_stride, output_stride=backbone_cfg.output_stride,
input_specs=input_specs, input_specs=input_specs,
multigrid=backbone_cfg.multigrid,
last_stage_repeats=backbone_cfg.last_stage_repeats,
stem_type=backbone_cfg.stem_type,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
......
...@@ -28,6 +28,7 @@ class ASPP(tf.keras.layers.Layer): ...@@ -28,6 +28,7 @@ class ASPP(tf.keras.layers.Layer):
level, level,
dilation_rates, dilation_rates,
num_filters=256, num_filters=256,
pool_kernel_size=None,
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
...@@ -43,6 +44,9 @@ class ASPP(tf.keras.layers.Layer): ...@@ -43,6 +44,9 @@ class ASPP(tf.keras.layers.Layer):
level: `int` level to apply ASPP. level: `int` level to apply ASPP.
dilation_rates: `list` of dilation rates. dilation_rates: `list` of dilation rates.
num_filters: `int` number of output filters in ASPP. num_filters: `int` number of output filters in ASPP.
pool_kernel_size: `list` of [height, width] of pooling kernel size or
None. Pooling size is with respect to original image size, it will be
scaled down by 2**level. If None, global average pooling is used.
use_sync_bn: if True, use synchronized batch normalization. use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization momentum for the moving average. norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by norm_epsilon: `float` small float added to variance to avoid dividing by
...@@ -60,6 +64,7 @@ class ASPP(tf.keras.layers.Layer): ...@@ -60,6 +64,7 @@ class ASPP(tf.keras.layers.Layer):
'level': level, 'level': level,
'dilation_rates': dilation_rates, 'dilation_rates': dilation_rates,
'num_filters': num_filters, 'num_filters': num_filters,
'pool_kernel_size': pool_kernel_size,
'use_sync_bn': use_sync_bn, 'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum, 'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon, 'norm_epsilon': norm_epsilon,
...@@ -71,9 +76,16 @@ class ASPP(tf.keras.layers.Layer): ...@@ -71,9 +76,16 @@ class ASPP(tf.keras.layers.Layer):
} }
def build(self, input_shape): def build(self, input_shape):
pool_kernel_size = None
if self._config_dict['pool_kernel_size']:
pool_kernel_size = [
int(p_size // 2**self._config_dict['level'])
for p_size in self._config_dict['pool_kernel_size']
]
self.aspp = keras_cv.layers.SpatialPyramidPooling( self.aspp = keras_cv.layers.SpatialPyramidPooling(
output_channels=self._config_dict['num_filters'], output_channels=self._config_dict['num_filters'],
dilation_rates=self._config_dict['dilation_rates'], dilation_rates=self._config_dict['dilation_rates'],
pool_kernel_size=pool_kernel_size,
use_sync_bn=self._config_dict['use_sync_bn'], use_sync_bn=self._config_dict['use_sync_bn'],
batchnorm_momentum=self._config_dict['norm_momentum'], batchnorm_momentum=self._config_dict['norm_momentum'],
batchnorm_epsilon=self._config_dict['norm_epsilon'], batchnorm_epsilon=self._config_dict['norm_epsilon'],
......
...@@ -61,6 +61,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase): ...@@ -61,6 +61,7 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
level=3, level=3,
dilation_rates=[6, 12], dilation_rates=[6, 12],
num_filters=256, num_filters=256,
pool_kernel_size=None,
use_sync_bn=False, use_sync_bn=False,
norm_momentum=0.99, norm_momentum=0.99,
norm_epsilon=0.001, norm_epsilon=0.001,
......
...@@ -70,6 +70,7 @@ def build_decoder(input_specs, ...@@ -70,6 +70,7 @@ def build_decoder(input_specs,
level=decoder_cfg.level, level=decoder_cfg.level,
dilation_rates=decoder_cfg.dilation_rates, dilation_rates=decoder_cfg.dilation_rates,
num_filters=decoder_cfg.num_filters, num_filters=decoder_cfg.num_filters,
pool_kernel_size=decoder_cfg.pool_kernel_size,
dropout_rate=decoder_cfg.dropout_rate, dropout_rate=decoder_cfg.dropout_rate,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
......
...@@ -102,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -102,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer):
conv_kwargs = { conv_kwargs = {
'kernel_size': 3, 'kernel_size': 3,
'padding': 'same', 'padding': 'same',
'bias_initializer': tf.zeros_initializer(), 'use_bias': False,
'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01), 'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01),
'kernel_regularizer': self._config_dict['kernel_regularizer'], 'kernel_regularizer': self._config_dict['kernel_regularizer'],
} }
...@@ -120,7 +120,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -120,7 +120,7 @@ class SegmentationHead(tf.keras.layers.Layer):
self._dlv3p_conv = conv_op( self._dlv3p_conv = conv_op(
kernel_size=1, kernel_size=1,
padding='same', padding='same',
bias_initializer=tf.zeros_initializer(), use_bias=False,
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'], kernel_regularizer=self._config_dict['kernel_regularizer'],
name='segmentation_head_deeplabv3p_fusion_conv', name='segmentation_head_deeplabv3p_fusion_conv',
...@@ -145,7 +145,12 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -145,7 +145,12 @@ class SegmentationHead(tf.keras.layers.Layer):
self._classifier = conv_op( self._classifier = conv_op(
name='segmentation_output', name='segmentation_output',
filters=self._config_dict['num_classes'], filters=self._config_dict['num_classes'],
**conv_kwargs) kernel_size=1,
padding='same',
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
super(SegmentationHead, self).build(input_shape) super(SegmentationHead, self).build(input_shape)
......
...@@ -81,17 +81,18 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -81,17 +81,18 @@ class SemanticSegmentationTask(base_task.Task):
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Builds classification input.""" """Builds classification input."""
input_size = self.task_config.model.input_size
ignore_label = self.task_config.losses.ignore_label ignore_label = self.task_config.losses.ignore_label
decoder = segmentation_input.Decoder() decoder = segmentation_input.Decoder()
parser = segmentation_input.Parser( parser = segmentation_input.Parser(
output_size=input_size[:2], output_size=params.output_size,
train_on_crops=params.train_on_crops,
ignore_label=ignore_label, ignore_label=ignore_label,
resize_eval_groundtruth=params.resize_eval_groundtruth, resize_eval_groundtruth=params.resize_eval_groundtruth,
groundtruth_padded_size=params.groundtruth_padded_size, groundtruth_padded_size=params.groundtruth_padded_size,
aug_scale_min=params.aug_scale_min, aug_scale_min=params.aug_scale_min,
aug_scale_max=params.aug_scale_max, aug_scale_max=params.aug_scale_max,
aug_rand_hflip=params.aug_rand_hflip,
dtype=params.dtype) dtype=params.dtype)
reader = input_reader.InputReader( reader = input_reader.InputReader(
...@@ -120,7 +121,8 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -120,7 +121,8 @@ class SemanticSegmentationTask(base_task.Task):
loss_params.label_smoothing, loss_params.label_smoothing,
loss_params.class_weights, loss_params.class_weights,
loss_params.ignore_label, loss_params.ignore_label,
use_groundtruth_dimension=loss_params.use_groundtruth_dimension) use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
top_k_percent_pixels=loss_params.top_k_percent_pixels)
total_loss = segmentation_loss_fn(model_outputs, labels['masks']) total_loss = segmentation_loss_fn(model_outputs, labels['masks'])
...@@ -133,19 +135,18 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -133,19 +135,18 @@ class SemanticSegmentationTask(base_task.Task):
"""Gets streaming metrics for training/validation.""" """Gets streaming metrics for training/validation."""
metrics = [] metrics = []
if training: if training:
# TODO(arashwan): make MeanIoU tpu friendly. metrics.append(segmentation_metrics.MeanIoU(
if not isinstance(tf.distribute.get_strategy(), name='mean_iou',
tf.distribute.TPUStrategy): num_classes=self.task_config.model.num_classes,
metrics.append(segmentation_metrics.MeanIoU( rescale_predictions=False,
name='mean_iou', dtype=tf.float32))
num_classes=self.task_config.model.num_classes,
rescale_predictions=False))
else: else:
self.miou_metric = segmentation_metrics.MeanIoU( self.miou_metric = segmentation_metrics.MeanIoU(
name='val_mean_iou', name='val_mean_iou',
num_classes=self.task_config.model.num_classes, num_classes=self.task_config.model.num_classes,
rescale_predictions=not self.task_config.validation_data rescale_predictions=not self.task_config.validation_data
.resize_eval_groundtruth) .resize_eval_groundtruth,
dtype=tf.float32)
return metrics return metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment