Commit 6f917c69 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 368036370
parent f281248c
@@ -31,8 +31,8 @@ task:
   losses:
     top_k_percent_pixels: 1.0 # only backpropagate loss for the topk 100% pixels.
   train_data:
-    output_size: [512, 1024]
-    train_on_crops: true
+    output_size: [1024, 2048]
+    crop_size: [512, 1024]
     input_path: ''
     tfds_name: 'cityscapes/semantic_segmentation'
     tfds_split: 'train'
......
@@ -33,9 +33,9 @@ from official.vision.beta.configs import decoders
 class DataConfig(cfg.DataConfig):
   """Input config for training."""
   output_size: List[int] = dataclasses.field(default_factory=list)
-  # If train_on_crops is set to True, a patch of size output_size is cropped
-  # from the input image.
-  train_on_crops: bool = False
+  # If crop_size is specified, the image is first resized to output_size and
+  # then a random crop of size crop_size is taken.
+  crop_size: List[int] = dataclasses.field(default_factory=list)
   input_path: str = ''
   global_batch_size: int = 0
   is_training: bool = True
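
For reference, a minimal sketch of how the new field is meant to be used when building a training `DataConfig` in Python. The sizes mirror the Cityscapes config above; `global_batch_size=16` is an illustrative value, not taken from this change:

```python
# Hypothetical usage sketch; field names follow the DataConfig shown above.
train_data = DataConfig(
    output_size=[1024, 2048],  # image (and label) are resized to this size first
    crop_size=[512, 1024],     # then a random crop of this size is taken for training
    input_path='',
    tfds_name='cityscapes/semantic_segmentation',
    tfds_split='train',
    global_batch_size=16,      # illustrative value
    is_training=True)
```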
@@ -56,9 +56,11 @@ class DataConfig(cfg.DataConfig):
 @dataclasses.dataclass
 class SegmentationHead(hyperparams.Config):
   """Segmentation head config."""
   level: int = 3
   num_convs: int = 2
   num_filters: int = 256
+  prediction_kernel_size: int = 1
   upsample_factor: int = 1
   feature_fusion: Optional[str] = None  # None, deeplabv3plus, or pyramid_fusion
   # deeplabv3plus feature fusion params
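
A quick sketch of how the new `prediction_kernel_size` knob could be set on the head config. The value 3 below is illustrative (the default stays 1, i.e. a 1x1 prediction conv), and only fields shown in the config above are used:

```python
# Hypothetical override of the SegmentationHead config.
head = SegmentationHead(
    level=3,
    num_convs=2,
    num_filters=256,
    prediction_kernel_size=3,  # illustrative: 3x3 conv for the final prediction layer
    upsample_factor=1,
    feature_fusion=None)
```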
@@ -433,8 +435,8 @@ def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
           train_data=DataConfig(
               input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                       'train_fine**'),
-              output_size=[512, 1024],
-              train_on_crops=True,
+              crop_size=[512, 1024],
+              output_size=[1024, 2048],
               is_training=True,
               global_batch_size=train_batch_size,
               aug_scale_min=0.5,
......
@@ -43,7 +43,7 @@ class Parser(parser.Parser):
   def __init__(self,
                output_size,
-               train_on_crops=False,
+               crop_size=None,
                resize_eval_groundtruth=True,
                groundtruth_padded_size=None,
                ignore_label=255,
@@ -56,9 +56,10 @@ class Parser(parser.Parser):
     Args:
       output_size: `Tensor` or `list` for [height, width] of output image. The
         output_size should be divided by the largest feature stride 2^max_level.
-      train_on_crops: `bool`, if True, a training crop of size output_size
-        is returned. This is useful for cropping original images during training
-        while evaluating on original image sizes.
+      crop_size: `Tensor` or `list` for [height, width] of the crop. If
+        specified, a training crop of size crop_size is returned. This is
+        useful for cropping original images during training while evaluating
+        on original image sizes.
       resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
         resized to output_size.
       groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
@@ -75,7 +76,7 @@ class Parser(parser.Parser):
       dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
     """
     self._output_size = output_size
-    self._train_on_crops = train_on_crops
+    self._crop_size = crop_size
     self._resize_eval_groundtruth = resize_eval_groundtruth
     if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
       raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
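
Putting the constructor change together, a hedged sketch of building the parser directly; the values are illustrative, and in the task code further below the same arguments come from the data config:

```python
# Hypothetical standalone construction; in practice the task wires these from params.
parser = Parser(
    output_size=[1024, 2048],   # evaluation still runs at this resolution
    crop_size=[512, 1024],      # training uses a random crop of this size
    resize_eval_groundtruth=True,
    ignore_label=255)
```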
@@ -110,24 +111,32 @@ class Parser(parser.Parser):
     """Parses data for training and evaluation."""
     image, label = self._prepare_image_and_label(data)
-    if self._train_on_crops:
+    if self._crop_size:
       label = tf.reshape(label, [data['image/height'], data['image/width'], 1])
+      # If output_size is specified, resize image and label to the desired
+      # output_size first.
+      if self._output_size:
+        image = tf.image.resize(image, self._output_size, method='bilinear')
+        label = tf.image.resize(label, self._output_size, method='nearest')
       image_mask = tf.concat([image, label], axis=2)
       image_mask_crop = tf.image.random_crop(image_mask,
-                                             self._output_size + [4])
+                                             self._crop_size + [4])
       image = image_mask_crop[:, :, :-1]
-      label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._output_size)
+      label = tf.reshape(image_mask_crop[:, :, -1], [1] + self._crop_size)

     # Flips image randomly during training.
     if self._aug_rand_hflip:
       image, _, label = preprocess_ops.random_horizontal_flip(
           image, masks=label)

+    train_image_size = self._crop_size if self._crop_size else self._output_size
     # Resizes and crops image.
     image, image_info = preprocess_ops.resize_and_crop_image(
         image,
-        self._output_size,
-        self._output_size,
+        train_image_size,
+        train_image_size,
         aug_scale_min=self._aug_scale_min,
         aug_scale_max=self._aug_scale_max)
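
The cropping branch above uses a common trick: the image and the single-channel label are concatenated along the channel axis so that one `tf.image.random_crop` call picks the same window for both. A self-contained sketch of that idea (a hypothetical helper, not part of this change):

```python
import tensorflow as tf

def random_crop_image_and_label(image, label, crop_size):
  """Takes the same random crop from an image and its segmentation label.

  image: float tensor of shape [height, width, 3].
  label: integer tensor of shape [height, width, 1].
  crop_size: [crop_height, crop_width] Python list.
  """
  label = tf.cast(label, image.dtype)             # concat needs a common dtype
  image_mask = tf.concat([image, label], axis=2)  # [H, W, 4]
  crop = tf.image.random_crop(image_mask, crop_size + [4])
  return crop[:, :, :-1], tf.cast(crop[:, :, -1:], tf.int32)

# Example: Cityscapes-sized input, 512x1024 training crop.
image = tf.random.uniform([1024, 2048, 3])
label = tf.zeros([1024, 2048, 1], tf.int32)
cropped_image, cropped_label = random_crop_image_and_label(image, label, [512, 1024])
```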
@@ -140,7 +149,7 @@ class Parser(parser.Parser):
     label += 1
     label = tf.expand_dims(label, axis=3)
     label = preprocess_ops.resize_and_crop_masks(
-        label, image_scale, self._output_size, offset)
+        label, image_scale, train_image_size, offset)
     label -= 1
     label = tf.where(tf.equal(label, -1),
                      self._ignore_label * tf.ones_like(label), label)
......
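The `label += 1` / `label -= 1` pair above exists because `preprocess_ops.resize_and_crop_masks` pads with zeros: shifting labels up by one before the resize/pad and back down afterwards makes padded pixels come out as -1, which `tf.where` then maps to `ignore_label`. A minimal sketch of the same idea using plain zero padding (it does not reimplement `resize_and_crop_masks`; the helper name is hypothetical):

```python
import tensorflow as tf

def pad_label_with_ignore(label, target_size, ignore_label=255):
  """Pads a [H, W] integer label map to target_size; new pixels get ignore_label."""
  pad_h = target_size[0] - label.shape[0]  # assumes statically known shapes
  pad_w = target_size[1] - label.shape[1]
  label = label + 1                        # shift so 0 can only mean "padded"
  label = tf.pad(label, [[0, pad_h], [0, pad_w]])  # zero padding on bottom/right
  label = label - 1                        # padded pixels are now -1
  return tf.where(tf.equal(label, -1),
                  ignore_label * tf.ones_like(label), label)

label = tf.ones([4, 6], tf.int32)
padded = pad_label_with_ignore(label, [8, 8])  # the new border pixels are 255
```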
@@ -272,6 +272,7 @@ def build_segmentation_model(
       num_classes=model_config.num_classes,
       level=head_config.level,
       num_convs=head_config.num_convs,
+      prediction_kernel_size=head_config.prediction_kernel_size,
       num_filters=head_config.num_filters,
       upsample_factor=head_config.upsample_factor,
       feature_fusion=head_config.feature_fusion,
......
@@ -30,6 +30,7 @@ class SegmentationHead(tf.keras.layers.Layer):
                level,
                num_convs=2,
                num_filters=256,
+               prediction_kernel_size=1,
                upsample_factor=1,
                feature_fusion=None,
                low_level=2,
@@ -51,6 +52,8 @@ class SegmentationHead(tf.keras.layers.Layer):
         prediction layer.
       num_filters: An `int` number to specify the number of filters used.
         Default is 256.
+      prediction_kernel_size: An `int` number to specify the kernel size of the
+        prediction layer.
       upsample_factor: An `int` number to specify the upsampling factor to
         generate finer mask. Default 1 means no upsampling is applied.
       feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, or None. If
@@ -80,6 +83,7 @@ class SegmentationHead(tf.keras.layers.Layer):
         'level': level,
         'num_convs': num_convs,
         'num_filters': num_filters,
+        'prediction_kernel_size': prediction_kernel_size,
         'upsample_factor': upsample_factor,
         'feature_fusion': feature_fusion,
         'low_level': low_level,
@@ -146,7 +150,7 @@ class SegmentationHead(tf.keras.layers.Layer):
     self._classifier = conv_op(
         name='segmentation_output',
         filters=self._config_dict['num_classes'],
-        kernel_size=1,
+        kernel_size=self._config_dict['prediction_kernel_size'],
         padding='same',
         bias_initializer=tf.zeros_initializer(),
         kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
@@ -193,8 +197,10 @@ class SegmentationHead(tf.keras.layers.Layer):
       x = conv(x)
       x = norm(x)
       x = self._activation(x)
+    if self._config_dict['upsample_factor'] > 1:
+      x = spatial_transform_ops.nearest_upsampling(
+          x, scale=self._config_dict['upsample_factor'])
     return self._classifier(x)

   def get_config(self):
......
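The last two hunks make the final prediction conv's kernel size configurable and only apply nearest-neighbour upsampling when `upsample_factor > 1`. A rough Keras sketch of that pattern, with `tf.keras.layers.UpSampling2D` standing in for `spatial_transform_ops.nearest_upsampling`; the helper name and values below are illustrative:

```python
import tensorflow as tf

def build_prediction_block(num_classes, prediction_kernel_size=1, upsample_factor=1):
  """Optional nearest upsampling followed by the prediction conv."""
  layers = []
  if upsample_factor > 1:
    # Stand-in for spatial_transform_ops.nearest_upsampling in the head above.
    layers.append(tf.keras.layers.UpSampling2D(
        size=upsample_factor, interpolation='nearest'))
  layers.append(tf.keras.layers.Conv2D(
      filters=num_classes,
      kernel_size=prediction_kernel_size,
      padding='same',
      bias_initializer=tf.zeros_initializer(),
      kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01)))
  return tf.keras.Sequential(layers)

# 19 Cityscapes classes, 3x3 prediction conv, 2x upsampling before prediction.
block = build_prediction_block(19, prediction_kernel_size=3, upsample_factor=2)
logits = block(tf.random.uniform([1, 64, 128, 256]))  # -> [1, 128, 256, 19]
```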
@@ -95,7 +95,7 @@ class SemanticSegmentationTask(base_task.Task):
     parser = segmentation_input.Parser(
         output_size=params.output_size,
-        train_on_crops=params.train_on_crops,
+        crop_size=params.crop_size,
         ignore_label=ignore_label,
         resize_eval_groundtruth=params.resize_eval_groundtruth,
         groundtruth_padded_size=params.groundtruth_padded_size,
......