Commit 03ae8d2d authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 338789107
parent 4e3e5fab
...@@ -18,4 +18,5 @@ ...@@ -18,4 +18,5 @@
from official.vision.beta.configs import image_classification from official.vision.beta.configs import image_classification
from official.vision.beta.configs import maskrcnn from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import retinanet from official.vision.beta.configs import retinanet
from official.vision.beta.configs import semantic_segmentation
from official.vision.beta.configs import video_classification from official.vision.beta.configs import video_classification
# ImageNet classification pretraining for the dilated ResNet-101 (DeepLab)
# backbone on TPU with bfloat16 mixed precision.
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    num_classes: 1001  # 1000 ImageNet classes plus one background/offset class.
    input_size: [224, 224, 3]
    backbone:
      type: 'dilated_resnet'
      dilated_resnet:
        model_id: 101
        output_stride: 8  # Dense features; matches downstream segmentation use.
    norm_activation:
      activation: 'swish'
  losses:
    l2_weight_decay: 0.0001
    one_hot: true
    label_smoothing: 0.1
  train_data:
    input_path: 'imagenet-2012-tfrecord/train*'
    is_training: true
    global_batch_size: 4096
    dtype: 'bfloat16'
  validation_data:
    input_path: 'imagenet-2012-tfrecord/valid*'
    is_training: false
    global_batch_size: 4096
    dtype: 'bfloat16'
    drop_remainder: false
trainer:
  # 312 steps/epoch at batch 4096; 62400 steps = 200 epochs.
  train_steps: 62400
  validation_steps: 13
  validation_interval: 312
  steps_per_loop: 312
  summary_interval: 312
  checkpoint_interval: 312
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'cosine'
      cosine:
        initial_learning_rate: 1.6
        decay_steps: 62400
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 1560  # 5 epochs of linear warmup.
# ImageNet classification pretraining for the dilated ResNet-50 (DeepLab)
# backbone on TPU with bfloat16 mixed precision.
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    num_classes: 1001  # 1000 ImageNet classes plus one background/offset class.
    input_size: [224, 224, 3]
    backbone:
      type: 'dilated_resnet'
      dilated_resnet:
        model_id: 50
        output_stride: 8  # Dense features; matches downstream segmentation use.
    norm_activation:
      activation: 'swish'
  losses:
    l2_weight_decay: 0.0001
    one_hot: true
    label_smoothing: 0.1
  train_data:
    input_path: 'imagenet-2012-tfrecord/train*'
    is_training: true
    global_batch_size: 4096
    dtype: 'bfloat16'
  validation_data:
    input_path: 'imagenet-2012-tfrecord/valid*'
    is_training: false
    global_batch_size: 4096
    dtype: 'bfloat16'
    drop_remainder: false
trainer:
  # 312 steps/epoch at batch 4096; 62400 steps = 200 epochs.
  train_steps: 62400
  validation_steps: 13
  validation_interval: 312
  steps_per_loop: 312
  summary_interval: 312
  checkpoint_interval: 312
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'cosine'
      cosine:
        initial_learning_rate: 1.6
        decay_steps: 62400
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 1560  # 5 epochs of linear warmup.
# Dilated ResNet-101 Pascal segmentation. 80.89 mean IOU.
# Pascal VOC segmentation override: dilated ResNet-101 backbone initialized
# from the ImageNet-pretrained DeepLab checkpoint (backbone weights only).
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'float32'
task:
  model:
    backbone:
      type: 'dilated_resnet'
      dilated_resnet:
        model_id: 101
        output_stride: 8
    norm_activation:
      activation: 'swish'
  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
  init_checkpoint_modules: 'backbone'
# Pascal VOC segmentation override: dilated ResNet-50 backbone initialized
# from the ImageNet-pretrained DeepLab checkpoint (backbone weights only).
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'float32'
task:
  model:
    backbone:
      type: 'dilated_resnet'
      dilated_resnet:
        model_id: 50
        output_stride: 8
    norm_activation:
      activation: 'swish'
  init_checkpoint: 'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400'
  init_checkpoint_modules: 'backbone'
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image segmentation configuration definition."""
import os
from typing import List, Union, Optional
import dataclasses
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.configs import decoders
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for segmentation training and evaluation."""
  input_path: str = ''  # Glob/path of the input files.
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'  # Input dtype, e.g. 'float32' or 'bfloat16'.
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10  # Presumably the file-interleave parallelism; verify.
  # If False, eval groundtruth masks keep their original size and are padded
  # to `groundtruth_padded_size` instead of being resized to the model input.
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  aug_scale_min: float = 1.0  # Lower bound for random scale augmentation.
  aug_scale_max: float = 1.0  # Upper bound for random scale augmentation.
@dataclasses.dataclass
class SegmentationHead(hyperparams.Config):
  """Segmentation head config."""
  level: int = 3  # Feature level (stride 2^level) the head consumes.
  num_convs: int = 2  # Number of conv layers before the classifier.
  num_filters: int = 256
  upsample_factor: int = 1
@dataclasses.dataclass
class ImageSegmentationModel(hyperparams.Config):
  """Image segmentation model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)  # [H, W, C].
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead = SegmentationHead()
  # Defaults to a plain ResNet backbone with an identity (pass-through)
  # decoder; experiments override these (e.g. dilated_resnet + ASPP).
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  decoder: decoders.Decoder = decoders.Decoder(type='identity')
  norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Segmentation loss config."""
  label_smoothing: float = 0.1
  ignore_label: int = 255  # Pixels with this label are excluded from the loss.
  class_weights: List[float] = dataclasses.field(default_factory=list)
  l2_weight_decay: float = 0.0
  # If True, logits are resized to the groundtruth mask size before the loss;
  # otherwise labels are resized (nearest neighbor) to the logit size.
  use_groundtruth_dimension: bool = True
@dataclasses.dataclass
class ImageSegmentationTask(cfg.TaskConfig):
  """The image segmentation task config."""
  model: ImageSegmentationModel = ImageSegmentationModel()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  gradient_clip_norm: float = 0.0  # <= 0 disables gradient clipping.
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[
      str, List[str]] = 'all'  # all, backbone, and/or decoder
@exp_factory.register_config_factory('semantic_segmentation')
def semantic_segmentation() -> cfg.ExperimentConfig:
  """Semantic segmentation general experiment config.

  Returns:
    An `ExperimentConfig` with a default `ImageSegmentationTask` and trainer.
  """
  return cfg.ExperimentConfig(
      # Bug fix: `task` must be the TaskConfig (ImageSegmentationTask); it was
      # previously constructed from the model config (ImageSegmentationModel),
      # which is not a TaskConfig and breaks `config.task.*` accesses such as
      # the `task.train_data.is_training` restrictions below.
      task=ImageSegmentationTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# PASCAL VOC 2012 Dataset.
PASCAL_TRAIN_EXAMPLES = 10582  # Size of the augmented 'train_aug' split.
PASCAL_VAL_EXAMPLES = 1449  # Size of the standard validation split.
PASCAL_INPUT_PATH_BASE = 'pascal_voc_seg'
@exp_factory.register_config_factory('seg_deeplabv3_pascal')
def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on Pascal VOC with DeepLabV3 (dilated ResNet-50).

  Returns:
    An `ExperimentConfig` for DeepLabV3: dilated ResNet-50 backbone (output
    stride 8), ASPP decoder, trained 45 epochs with polynomial LR decay,
    starting from an ImageNet-pretrained backbone checkpoint.
  """
  train_batch_size = 16
  eval_batch_size = 8
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageSegmentationTask(
          model=ImageSegmentationModel(
              num_classes=21,  # 20 Pascal VOC classes + background.
              # TODO(arashwan): test changing size to 513 to match deeplab.
              input_size=[512, 512, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
                      model_id=50, output_stride=8)),
              decoder=decoders.Decoder(
                  type='aspp', aspp=decoders.ASPP(
                      level=3, dilation_rates=[12, 24, 36])),
              # num_convs=0: the ASPP output feeds the classifier directly.
              head=SegmentationHead(level=3, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.9997,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              # Keep eval groundtruth at original resolution (padded to
              # 512x512) so mean IoU is computed on the original mask sizes.
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512]),
          # resnet50
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=45 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 45 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for semantic_segmentation."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.vision import beta
from official.vision.beta.configs import semantic_segmentation as exp_cfg
class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Checks registered semantic segmentation experiment configs."""

  @parameterized.parameters(('seg_deeplabv3_pascal',),)
  def test_semantic_segmentation_configs(self, config_name):
    experiment = exp_factory.get_exp_config(config_name)
    # The factory must hand back a fully-typed experiment config tree.
    self.assertIsInstance(experiment, cfg.ExperimentConfig)
    self.assertIsInstance(experiment.task, exp_cfg.ImageSegmentationTask)
    self.assertIsInstance(experiment.task.model,
                          exp_cfg.ImageSegmentationModel)
    self.assertIsInstance(experiment.task.train_data, exp_cfg.DataConfig)
    # Clearing the is_training flag must violate the config restrictions.
    experiment.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      experiment.validate()


if __name__ == '__main__':
  tf.test.main()
...@@ -42,7 +42,8 @@ class Parser(parser.Parser): ...@@ -42,7 +42,8 @@ class Parser(parser.Parser):
def __init__(self, def __init__(self,
output_size, output_size,
resize_eval=False, resize_eval_groundtruth=True,
groundtruth_padded_size=None,
ignore_label=255, ignore_label=255,
aug_rand_hflip=False, aug_rand_hflip=False,
aug_scale_min=1.0, aug_scale_min=1.0,
...@@ -53,8 +54,11 @@ class Parser(parser.Parser): ...@@ -53,8 +54,11 @@ class Parser(parser.Parser):
Args: Args:
output_size: `Tensor` or `list` for [height, width] of output image. The output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level. output_size should be divided by the largest feature stride 2^max_level.
resize_eval: 'bool', if True, during evaluation, the max side of image and resize_eval_groundtruth: `bool`, if True, eval groundtruth masks are
label will be resized to output_size, otherwise image will be padded. resized to output_size.
groundtruth_padded_size: `Tensor` or `list` for [height, width]. When
resize_eval_groundtruth is set to False, the groundtruth masks are
padded to this size.
ignore_label: `int` the pixel with ignore label will not used for training ignore_label: `int` the pixel with ignore label will not used for training
and evaluation. and evaluation.
aug_rand_hflip: `bool`, if True, augment training with random aug_rand_hflip: `bool`, if True, augment training with random
...@@ -66,7 +70,11 @@ class Parser(parser.Parser): ...@@ -66,7 +70,11 @@ class Parser(parser.Parser):
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}. dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
""" """
self._output_size = output_size self._output_size = output_size
self._resize_eval = resize_eval self._resize_eval_groundtruth = resize_eval_groundtruth
if (not resize_eval_groundtruth) and (groundtruth_padded_size is None):
raise ValueError('groundtruth_padded_size ([height, width]) needs to be'
'specified when resize_eval_groundtruth is False.')
self._groundtruth_padded_size = groundtruth_padded_size
self._ignore_label = ignore_label self._ignore_label = ignore_label
# Data augmentation. # Data augmentation.
...@@ -125,7 +133,8 @@ class Parser(parser.Parser): ...@@ -125,7 +133,8 @@ class Parser(parser.Parser):
valid_mask = tf.not_equal(label, self._ignore_label) valid_mask = tf.not_equal(label, self._ignore_label)
labels = { labels = {
'masks': label, 'masks': label,
'valid_masks': valid_mask 'valid_masks': valid_mask,
'image_info': image_info,
} }
# Cast image as self._dtype # Cast image as self._dtype
...@@ -140,23 +149,21 @@ class Parser(parser.Parser): ...@@ -140,23 +149,21 @@ class Parser(parser.Parser):
label += 1 label += 1
label = tf.expand_dims(label, axis=3) label = tf.expand_dims(label, axis=3)
if self._resize_eval: # Resizes and crops image.
# Resizes and crops image. image, image_info = preprocess_ops.resize_and_crop_image(
image, image_info = preprocess_ops.resize_and_crop_image( image, self._output_size, self._output_size)
image, self._output_size, self._output_size)
# Resizes and crops mask. if self._resize_eval_groundtruth:
# Resizes eval masks to match input image sizes. In that case, mean IoU
# is computed on output_size not the original size of the images.
image_scale = image_info[2, :] image_scale = image_info[2, :]
offset = image_info[3, :] offset = image_info[3, :]
label = preprocess_ops.resize_and_crop_masks(label, image_scale, label = preprocess_ops.resize_and_crop_masks(label, image_scale,
self._output_size, offset) self._output_size, offset)
else: else:
# Pads image and mask to output size. label = tf.image.pad_to_bounding_box(
image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0], label, 0, 0, self._groundtruth_padded_size[0],
self._output_size[1]) self._groundtruth_padded_size[1])
label = tf.image.pad_to_bounding_box(label, 0, 0, self._output_size[0],
self._output_size[1])
label -= 1 label -= 1
label = tf.where(tf.equal(label, -1), label = tf.where(tf.equal(label, -1),
...@@ -166,7 +173,8 @@ class Parser(parser.Parser): ...@@ -166,7 +173,8 @@ class Parser(parser.Parser):
valid_mask = tf.not_equal(label, self._ignore_label) valid_mask = tf.not_equal(label, self._ignore_label)
labels = { labels = {
'masks': label, 'masks': label,
'valid_masks': valid_mask 'valid_masks': valid_mask,
'image_info': image_info
} }
# Cast image as self._dtype # Cast image as self._dtype
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Metrics for segmentation."""
import tensorflow as tf
class MeanIoU(tf.keras.metrics.MeanIoU):
  """Mean IoU metric for semantic segmentation.

  This class utilizes tf.keras.metrics.MeanIoU to perform batched mean iou when
  both input images and groundtruth masks are resized to the same size
  (rescale_predictions=False). It also computes mean iou on groundtruth
  original sizes, in which case, each prediction is rescaled back to the
  original image size.
  """

  def __init__(
      self, num_classes, rescale_predictions=False, name=None, dtype=None):
    """Constructs Segmentation evaluator class.

    Args:
      num_classes: `int`, number of classes.
      rescale_predictions: `bool`, whether to scale back prediction to original
        image sizes. If True, y_true['image_info'] is used to rescale
        predictions.
      name: `str`, name of the metric instance.
      dtype: data type of the metric result.
    """
    self._rescale_predictions = rescale_predictions
    super(MeanIoU, self).__init__(
        num_classes=num_classes, name=name, dtype=dtype)

  def update_state(self, y_true, y_pred):
    """Updates metric state.

    Args:
      y_true: `dict`, dictionary with the following name, and key values.
        - masks: [batch, width, height, 1], groundtruth masks.
        - valid_masks: [batch, width, height, 1], valid elements in the mask.
        - image_info: [batch, 4, 2], a tensor that holds information about
          original and preprocessed images. Each entry is in the format of
          [[original_height, original_width], [input_height, input_width],
          [y_scale, x_scale], [y_offset, x_offset]], where
          [desired_height, desired_width] is the actual scaled image size, and
          [y_scale, x_scale] is the scaling factor, which is the ratio of
          scaled dimension / original dimension.
      y_pred: Tensor [batch, width_p, height_p, num_classes], predicted masks.
    """
    predictions = y_pred
    masks = y_true['masks']
    valid_masks = y_true['valid_masks']
    images_info = y_true['image_info']
    # Per-replica outputs may arrive as a tuple/list of tensors; merge them
    # along the batch dimension. (Idiom fix: single isinstance with a tuple of
    # types instead of two chained isinstance calls.)
    if isinstance(predictions, (tuple, list)):
      predictions = tf.concat(predictions, axis=0)
      masks = tf.concat(masks, axis=0)
      valid_masks = tf.concat(valid_masks, axis=0)
      images_info = tf.concat(images_info, axis=0)

    # Ignored mask elements are set to zero for the argmax op; they are later
    # excluded via the sample-weight argument of update_state below.
    masks = tf.where(valid_masks, masks, tf.zeros_like(masks))

    if self._rescale_predictions:
      # This part can only run on cpu/gpu due to dynamic image resizing.
      flatten_predictions = []
      flatten_masks = []
      flatten_valid_masks = []
      for mask, valid_mask, predicted_mask, image_info in zip(
          masks, valid_masks, predictions, images_info):
        # Undo the preprocessing scale, then crop away the padding so the
        # prediction lines up with the original image extent.
        rescale_size = tf.cast(
            tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
        image_shape = tf.cast(image_info[0, :], tf.int32)
        offsets = tf.cast(image_info[3, :], tf.int32)

        predicted_mask = tf.image.resize(
            predicted_mask,
            rescale_size,
            method=tf.image.ResizeMethod.BILINEAR)
        predicted_mask = tf.image.crop_to_bounding_box(predicted_mask,
                                                       offsets[0], offsets[1],
                                                       image_shape[0],
                                                       image_shape[1])
        mask = tf.image.crop_to_bounding_box(mask, 0, 0, image_shape[0],
                                             image_shape[1])
        valid_mask = tf.image.crop_to_bounding_box(valid_mask, 0, 0,
                                                   image_shape[0],
                                                   image_shape[1])
        predicted_mask = tf.argmax(predicted_mask, axis=2)
        flatten_predictions.append(tf.reshape(predicted_mask, shape=[1, -1]))
        flatten_masks.append(tf.reshape(mask, shape=[1, -1]))
        flatten_valid_masks.append(tf.reshape(valid_mask, shape=[1, -1]))

      flatten_predictions = tf.concat(flatten_predictions, axis=1)
      flatten_masks = tf.concat(flatten_masks, axis=1)
      flatten_valid_masks = tf.concat(flatten_valid_masks, axis=1)
    else:
      # Masks and predictions share the preprocessed size: resize logits to
      # the mask resolution and take per-pixel argmax.
      predictions = tf.image.resize(
          predictions,
          tf.shape(masks)[1:3],
          method=tf.image.ResizeMethod.BILINEAR)
      predictions = tf.argmax(predictions, axis=3)
      flatten_predictions = tf.reshape(predictions, shape=[-1])
      flatten_masks = tf.reshape(masks, shape=[-1])
      flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])

    super(MeanIoU, self).update_state(
        flatten_masks, flatten_predictions,
        tf.cast(flatten_valid_masks, tf.float32))
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Losses used for segmentation models."""
# Import libraries
import tensorflow as tf
EPSILON = 1e-5
class SegmentationLoss:
  """Semantic segmentation loss: smoothed softmax cross entropy over pixels."""

  def __init__(self, label_smoothing, class_weights,
               ignore_label, use_groundtruth_dimension):
    """Initializes the loss.

    Args:
      label_smoothing: `float`, amount of label smoothing applied to the
        one-hot targets.
      class_weights: `list` of per-class loss weights; empty means uniform.
      ignore_label: `int`, pixels with this label are excluded from the loss.
      use_groundtruth_dimension: `bool`, if True, logits are resized to the
        label size; otherwise labels are resized (nearest neighbor) to the
        logit size.
    """
    self._class_weights = class_weights
    self._ignore_label = ignore_label
    self._use_groundtruth_dimension = use_groundtruth_dimension
    self._label_smoothing = label_smoothing

  def __call__(self, logits, labels):
    """Computes the normalized cross-entropy loss.

    Args:
      logits: [batch, height, width, num_classes] float logits.
      labels: [batch, height', width', 1] integer groundtruth masks.

    Returns:
      A scalar loss tensor (sum over valid pixels / number of valid pixels).

    Raises:
      ValueError: if a non-empty `class_weights` does not have `num_classes`
        entries.
    """
    _, height, width, num_classes = logits.get_shape().as_list()

    # Validate class weights up front, before any resize/cross-entropy work
    # is spent (previously the check ran only after computing the loss).
    class_weights = self._class_weights or [1] * num_classes
    if num_classes != len(class_weights):
      raise ValueError(
          'Length of class_weights should be {}'.format(num_classes))

    if self._use_groundtruth_dimension:
      # TODO(arashwan): Test using align corners to match deeplab alignment.
      logits = tf.image.resize(
          logits, tf.shape(labels)[1:3],
          method=tf.image.ResizeMethod.BILINEAR)
    else:
      labels = tf.image.resize(
          labels, (height, width),
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    valid_mask = tf.not_equal(labels, self._ignore_label)
    normalizer = tf.reduce_sum(tf.cast(valid_mask, tf.float32)) + EPSILON

    # Assign pixel with ignore label to class 0 (background). The loss on the
    # pixel will later be masked out.
    labels = tf.where(valid_mask, labels, tf.zeros_like(labels))

    labels = tf.squeeze(tf.cast(labels, tf.int32), axis=3)
    valid_mask = tf.squeeze(tf.cast(valid_mask, tf.float32), axis=3)
    onehot_labels = tf.one_hot(labels, num_classes)
    onehot_labels = onehot_labels * (
        1 - self._label_smoothing) + self._label_smoothing / num_classes
    cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=onehot_labels, logits=logits)

    # Fold the per-class weights into the validity mask.
    weight_mask = tf.einsum('...y,y->...',
                            tf.one_hot(labels, num_classes, dtype=tf.float32),
                            tf.constant(class_weights, tf.float32))
    valid_mask *= weight_mask

    # valid_mask is already float32 here; the original redundant cast removed.
    cross_entropy_loss *= valid_mask
    loss = tf.reduce_sum(cross_entropy_loss) / normalizer
    return loss
...@@ -15,18 +15,22 @@ ...@@ -15,18 +15,22 @@
"""Factory methods to build models.""" """Factory methods to build models."""
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.beta.configs import image_classification as classification_cfg from official.vision.beta.configs import image_classification as classification_cfg
from official.vision.beta.configs import maskrcnn as maskrcnn_cfg from official.vision.beta.configs import maskrcnn as maskrcnn_cfg
from official.vision.beta.configs import retinanet as retinanet_cfg from official.vision.beta.configs import retinanet as retinanet_cfg
from official.vision.beta.configs import semantic_segmentation as segmentation_cfg
from official.vision.beta.modeling import backbones from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import classification_model from official.vision.beta.modeling import classification_model
from official.vision.beta.modeling import maskrcnn_model from official.vision.beta.modeling import maskrcnn_model
from official.vision.beta.modeling import retinanet_model from official.vision.beta.modeling import retinanet_model
from official.vision.beta.modeling import segmentation_model
from official.vision.beta.modeling.decoders import factory as decoder_factory from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import dense_prediction_heads from official.vision.beta.modeling.heads import dense_prediction_heads
from official.vision.beta.modeling.heads import instance_heads from official.vision.beta.modeling.heads import instance_heads
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.modeling.layers import detection_generator from official.vision.beta.modeling.layers import detection_generator
from official.vision.beta.modeling.layers import mask_sampler from official.vision.beta.modeling.layers import mask_sampler
from official.vision.beta.modeling.layers import roi_aligner from official.vision.beta.modeling.layers import roi_aligner
...@@ -233,3 +237,37 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec, ...@@ -233,3 +237,37 @@ def build_retinanet(input_specs: tf.keras.layers.InputSpec,
model = retinanet_model.RetinaNetModel( model = retinanet_model.RetinaNetModel(
backbone, decoder, head, detection_generator_obj) backbone, decoder, head, detection_generator_obj)
return model return model
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.ImageSegmentationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds Segmentation model.

  Assembles backbone -> decoder -> segmentation head into a SegmentationModel.

  Args:
    input_specs: `tf.keras.layers.InputSpec` of the model input.
    model_config: `ImageSegmentationModel` config with backbone, decoder,
      head, and normalization/activation settings.
    l2_regularizer: optional Keras regularizer applied to model weights.

  Returns:
    A `segmentation_model.SegmentationModel` instance.
  """
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # The decoder consumes the backbone's multi-level output specs.
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_config = model_config.head
  norm_activation_config = model_config.norm_activation

  head = segmentation_heads.SegmentationHead(
      num_classes=model_config.num_classes,
      level=head_config.level,
      num_convs=head_config.num_convs,
      num_filters=head_config.num_filters,
      upsample_factor=head_config.upsample_factor,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  model = segmentation_model.SegmentationModel(backbone, decoder, head)
  return model
...@@ -18,4 +18,5 @@ ...@@ -18,4 +18,5 @@
from official.vision.beta.tasks import image_classification from official.vision.beta.tasks import image_classification
from official.vision.beta.tasks import maskrcnn from official.vision.beta.tasks import maskrcnn
from official.vision.beta.tasks import retinanet from official.vision.beta.tasks import retinanet
from official.vision.beta.tasks import semantic_segmentation
from official.vision.beta.tasks import video_classification from official.vision.beta.tasks import video_classification
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image segmentation task definition."""
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.configs import semantic_segmentation as exp_cfg
from official.vision.beta.dataloaders import segmentation_input
from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.losses import segmentation_losses
from official.vision.beta.modeling import factory
@task_factory.register_task_cls(exp_cfg.ImageSegmentationTask)
class ImageSegmentationTask(base_task.Task):
  """A task for image semantic segmentation."""

  def build_model(self):
    """Builds the segmentation model from the task config."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_segmentation_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint into `model` per the task config."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      # Full restore: every checkpointed variable must be matched.
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    else:
      # Partial restore of the requested submodules (backbone and/or decoder).
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)
      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self, params, input_context=None):
    """Builds the segmentation input pipeline (decode, parse, batch)."""
    input_size = self.task_config.model.input_size
    ignore_label = self.task_config.losses.ignore_label

    decoder = segmentation_input.Decoder()
    parser = segmentation_input.Parser(
        output_size=input_size[:2],
        ignore_label=ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self, labels, model_outputs, aux_losses=None):
    """Builds the segmentation cross entropy loss.

    Args:
      labels: labels dict; `labels['masks']` holds the groundtruth masks.
      model_outputs: Output logits of the model.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    loss_params = self._task_config.losses
    segmentation_loss_fn = segmentation_losses.SegmentationLoss(
        loss_params.label_smoothing,
        loss_params.class_weights,
        loss_params.ignore_label,
        use_groundtruth_dimension=loss_params.use_groundtruth_dimension)

    total_loss = segmentation_loss_fn(model_outputs, labels['masks'])

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return total_loss

  def build_metrics(self, training=True):
    """Gets streaming metrics for training/validation."""
    metrics = []
    if training:
      # TODO(arashwan): make MeanIoU tpu friendly.
      if not isinstance(tf.distribute.get_strategy(),
                        tf.distribute.experimental.TPUStrategy):
        metrics.append(segmentation_metrics.MeanIoU(
            name='mean_iou',
            num_classes=self.task_config.model.num_classes,
            rescale_predictions=False))
    else:
      # Eval mean IoU is kept on the task and updated through aggregate_logs
      # rather than returned in `metrics`.
      self.miou_metric = segmentation_metrics.MeanIoU(
          name='val_mean_iou',
          num_classes=self.task_config.model.num_classes,
          # Rescale predictions to original image size when eval groundtruth
          # was not resized to the model input size.
          rescale_predictions=not self.task_config.validation_data
          .resize_eval_groundtruth)

    return metrics

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(
        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)

    # Apply gradient clipping.
    if self.task_config.gradient_clip_norm > 0:
      grads, _ = tf.clip_by_global_norm(
          grads, self.task_config.gradient_clip_norm)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validatation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(model_outputs=outputs, labels=labels,
                             aux_losses=model.losses)
    logs = {self.loss: loss}

    # Stash (labels, outputs) so mean IoU can be updated on the host in
    # aggregate_logs (dynamic resizing is not TPU-friendly).
    logs.update({self.miou_metric.name: (labels, outputs)})

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def inference_step(self, inputs, model):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self, state=None, step_outputs=None):
    # First call of an eval run: reset the metric and use it as the state.
    if state is None:
      self.miou_metric.reset_states()
      state = self.miou_metric
    self.miou_metric.update_state(step_outputs[self.miou_metric.name][0],
                                  step_outputs[self.miou_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    # Materialize the accumulated mean IoU as a numpy scalar for reporting.
    return {self.miou_metric.name: self.miou_metric.result().numpy()}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment