Commit 6f917c69 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368036370
parent f281248c
...@@ -31,8 +31,8 @@ task: ...@@ -31,8 +31,8 @@ task:
losses: losses:
top_k_percent_pixels: 1.0 # only backpropagate loss for the topk 100% pixels. top_k_percent_pixels: 1.0 # only backpropagate loss for the topk 100% pixels.
train_data: train_data:
output_size: [512, 1024] output_size: [1024, 2048]
train_on_crops: true crop_size: [512, 1024]
input_path: '' input_path: ''
tfds_name: 'cityscapes/semantic_segmentation' tfds_name: 'cityscapes/semantic_segmentation'
tfds_split: 'train' tfds_split: 'train'
......
...@@ -33,9 +33,9 @@ from official.vision.beta.configs import decoders ...@@ -33,9 +33,9 @@ from official.vision.beta.configs import decoders
class DataConfig(cfg.DataConfig): class DataConfig(cfg.DataConfig):
"""Input config for training.""" """Input config for training."""
output_size: List[int] = dataclasses.field(default_factory=list) output_size: List[int] = dataclasses.field(default_factory=list)
# If train_on_crops is set to True, a patch of size output_size is cropped # If crop_size is specified, image will be resized first to
# from the input image. # output_size, then crop of size crop_size will be cropped.
train_on_crops: bool = False crop_size: List[int] = dataclasses.field(default_factory=list)
input_path: str = '' input_path: str = ''
global_batch_size: int = 0 global_batch_size: int = 0
is_training: bool = True is_training: bool = True
...@@ -56,9 +56,11 @@ class DataConfig(cfg.DataConfig): ...@@ -56,9 +56,11 @@ class DataConfig(cfg.DataConfig):
@dataclasses.dataclass @dataclasses.dataclass
class SegmentationHead(hyperparams.Config): class SegmentationHead(hyperparams.Config):
"""Segmentation head config."""
level: int = 3 level: int = 3
num_convs: int = 2 num_convs: int = 2
num_filters: int = 256 num_filters: int = 256
prediction_kernel_size: int = 1
upsample_factor: int = 1 upsample_factor: int = 1
feature_fusion: Optional[str] = None # None, deeplabv3plus, or pyramid_fusion feature_fusion: Optional[str] = None # None, deeplabv3plus, or pyramid_fusion
# deeplabv3plus feature fusion params # deeplabv3plus feature fusion params
...@@ -433,8 +435,8 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig: ...@@ -433,8 +435,8 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
train_data=DataConfig( train_data=DataConfig(
input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
'train_fine**'), 'train_fine**'),
output_size=[512, 1024], crop_size=[512, 1024],
train_on_crops=True, output_size=[1024, 2048],
is_training=True, is_training=True,
global_batch_size=train_batch_size, global_batch_size=train_batch_size,
aug_scale_min=0.5, aug_scale_min=0.5,
......
...@@ -43,7 +43,7 @@ class Parser(parser.Parser): ...@@ -43,7 +43,7 @@ class Parser(parser.Parser):
def __init__(self, def __init__(self,
output_size, output_size,
train_on_crops=False, crop_size=None,
resize_eval_groundtruth=True, resize_eval_groundtruth=True,
groundtruth_padded_size=None, groundtruth_padded_size=None,
ignore_label=255, ignore_label=255,
...@@ -56,9 +56,10 @@ class Parser(parser.Parser): ...@@ -56,9 +56,10 @@ class Parser(parser.Parser):
Args: Args:
output_size: `Tensor` or `list` for [height, width] of output image. The output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level. output_size should be divided by the largest feature stride 2^max_level.
train_on_crops: `bool`, if True, a training crop of size output_size crop_size: `Tensor` or `list` for [height, width] of the crop. If
is returned. This is useful for cropping original images during training specified a training crop of size crop_size is returned. This is useful
while evaluating on original image sizes. for cropping original images during training while evaluating on
original image sizes.
resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
resized to output_size. resized to output_size.
groundtruth_padded_size: `Tensor` or `list` for [height, width]. When groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
...@@ -75,7 +76,7 @@ class Parser(parser.Parser): ...@@ -75,7 +76,7 @@ class Parser(parser.Parser):
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}. dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
""" """
self._output_size = output_size self._output_size = output_size
self._train_on_crops = train_on_crops self._crop_size = crop_size
self._resize_eval_groundtruth = resize_eval_groundtruth self._resize_eval_groundtruth = resize_eval_groundtruth
if (not resize_eval_groundtruth) and (groundtruth_padded_size is None): if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
raise ValueError('groundtruth_padded_size ([height, width]) needs to be' raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
...@@ -110,24 +111,32 @@ class Parser(parser.Parser): ...@@ -110,24 +111,32 @@ class Parser(parser.Parser):
"""Parses data for training and evaluation.""" """Parses data for training and evaluation."""
image, label = self._prepare_image_and_label(data) image, label = self._prepare_image_and_label(data)
if self._train_on_crops: if self._crop_size:
label = tf.reshape(label, [data['image/height'], data['image/width'], 1]) label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
# If output_size is specified, resize image, and label to desired
# output_size.
if self._output_size:
image = tf.image.resize(image, self._output_size, method='bilinear')
label = tf.image.resize(label, self._output_size, method='nearest')
image_mask = tf.concat([image, label], axis=2) image_mask = tf.concat([image, label], axis=2)
image_mask_crop = tf.image.random_crop(image_mask, image_mask_crop = tf.image.random_crop(image_mask,
self._output_size + [4]) self._crop_size + [4])
image = image_mask_crop[:, :, :-1] image = image_mask_crop[:, :, :-1]
label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._output_size) label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._crop_size)
# Flips image randomly during training. # Flips image randomly during training.
if self._aug_rand_hflip: if self._aug_rand_hflip:
image, _, label = preprocess_ops.random_horizontal_flip( image, _, label = preprocess_ops.random_horizontal_flip(
image, masks=label) image, masks=label)
train_image_size = self._crop_size if self._crop_size else self._output_size
# Resizes and crops image. # Resizes and crops image.
image, image_info = preprocess_ops.resize_and_crop_image( image, image_info = preprocess_ops.resize_and_crop_image(
image, image,
self._output_size, train_image_size,
self._output_size, train_image_size,
aug_scale_min=self._aug_scale_min, aug_scale_min=self._aug_scale_min,
aug_scale_max=self._aug_scale_max) aug_scale_max=self._aug_scale_max)
...@@ -140,7 +149,7 @@ class Parser(parser.Parser): ...@@ -140,7 +149,7 @@ class Parser(parser.Parser):
label += 1 label += 1
label = tf.expand_dims(label, axis=3) label = tf.expand_dims(label, axis=3)
label = preprocess_ops.resize_and_crop_masks( label = preprocess_ops.resize_and_crop_masks(
label, image_scale, self._output_size, offset) label, image_scale, train_image_size, offset)
label -= 1 label -= 1
label = tf.where(tf.equal(label, -1), label = tf.where(tf.equal(label, -1),
self._ignore_label * tf.ones_like(label), label) self._ignore_label * tf.ones_like(label), label)
......
...@@ -272,6 +272,7 @@ def build_segmentation_model( ...@@ -272,6 +272,7 @@ def build_segmentation_model(
num_classes=model_config.num_classes, num_classes=model_config.num_classes,
level=head_config.level, level=head_config.level,
num_convs=head_config.num_convs, num_convs=head_config.num_convs,
prediction_kernel_size=head_config.prediction_kernel_size,
num_filters=head_config.num_filters, num_filters=head_config.num_filters,
upsample_factor=head_config.upsample_factor, upsample_factor=head_config.upsample_factor,
feature_fusion=head_config.feature_fusion, feature_fusion=head_config.feature_fusion,
......
...@@ -30,6 +30,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -30,6 +30,7 @@ class SegmentationHead(tf.keras.layers.Layer):
level, level,
num_convs=2, num_convs=2,
num_filters=256, num_filters=256,
prediction_kernel_size=1,
upsample_factor=1, upsample_factor=1,
feature_fusion=None, feature_fusion=None,
low_level=2, low_level=2,
...@@ -51,6 +52,8 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -51,6 +52,8 @@ class SegmentationHead(tf.keras.layers.Layer):
prediction layer. prediction layer.
num_filters: An `int` number to specify the number of filters used. num_filters: An `int` number to specify the number of filters used.
Default is 256. Default is 256.
prediction_kernel_size: An `int` number to specify the kernel size of the
prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to upsample_factor: An `int` number to specify the upsampling factor to
generate finer mask. Default 1 means no upsampling is applied. generate finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, or None. If feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, or None. If
...@@ -80,6 +83,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -80,6 +83,7 @@ class SegmentationHead(tf.keras.layers.Layer):
'level': level, 'level': level,
'num_convs': num_convs, 'num_convs': num_convs,
'num_filters': num_filters, 'num_filters': num_filters,
'prediction_kernel_size': prediction_kernel_size,
'upsample_factor': upsample_factor, 'upsample_factor': upsample_factor,
'feature_fusion': feature_fusion, 'feature_fusion': feature_fusion,
'low_level': low_level, 'low_level': low_level,
...@@ -146,7 +150,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -146,7 +150,7 @@ class SegmentationHead(tf.keras.layers.Layer):
self._classifier = conv_op( self._classifier = conv_op(
name='segmentation_output', name='segmentation_output',
filters=self._config_dict['num_classes'], filters=self._config_dict['num_classes'],
kernel_size=1, kernel_size=self._config_dict['prediction_kernel_size'],
padding='same', padding='same',
bias_initializer=tf.zeros_initializer(), bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
...@@ -193,8 +197,10 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -193,8 +197,10 @@ class SegmentationHead(tf.keras.layers.Layer):
x = conv(x) x = conv(x)
x = norm(x) x = norm(x)
x = self._activation(x) x = self._activation(x)
x = spatial_transform_ops.nearest_upsampling( if self._config_dict['upsample_factor'] > 1:
x, scale=self._config_dict['upsample_factor']) x = spatial_transform_ops.nearest_upsampling(
x, scale=self._config_dict['upsample_factor'])
return self._classifier(x) return self._classifier(x)
def get_config(self): def get_config(self):
......
...@@ -95,7 +95,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -95,7 +95,7 @@ class SemanticSegmentationTask(base_task.Task):
parser = segmentation_input.Parser( parser = segmentation_input.Parser(
output_size=params.output_size, output_size=params.output_size,
train_on_crops=params.train_on_crops, crop_size=params.crop_size,
ignore_label=ignore_label, ignore_label=ignore_label,
resize_eval_groundtruth=params.resize_eval_groundtruth, resize_eval_groundtruth=params.resize_eval_groundtruth,
groundtruth_padded_size=params.groundtruth_padded_size, groundtruth_padded_size=params.groundtruth_padded_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment