Commit 32e4ca51 authored by qianyj

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,8 +16,8 @@
 import tensorflow as tf
-from official.vision.beta.ops import sampling_ops
-from official.vision.beta.projects.centernet.ops import box_list
+from official.projects.centernet.ops import box_list
+from official.vision.ops import sampling_ops
 def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
 import tensorflow as tf
-from official.vision.beta.ops import sampling_ops
+from official.vision.ops import sampling_ops
 def _get_shape(tensor, num_dims):
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
 import tensorflow as tf
-from official.vision.beta.projects.yolo.ops import box_ops
+from official.projects.yolo.ops import box_ops
 NMS_TILE_SIZE = 512
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing ops imported from OD API."""
import functools
import tensorflow as tf
from official.projects.centernet.ops import box_list
from official.projects.centernet.ops import box_list_ops
def _get_or_create_preprocess_rand_vars(generator_func,
function_id,
preprocess_vars_cache,
key=''):
"""Returns a tensor stored in preprocess_vars_cache or using generator_func.
If the tensor was previously generated and appears in the PreprocessorCache,
the previously generated tensor will be returned. Otherwise, a new tensor
is generated using generator_func and stored in the cache.
Args:
generator_func: A 0-argument function that generates a tensor.
function_id: identifier for the preprocessing function used.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
key: identifier for the variable stored.
Returns:
The generated tensor.
"""
if preprocess_vars_cache is not None:
var = preprocess_vars_cache.get(function_id, key)
if var is None:
var = generator_func()
preprocess_vars_cache.update(function_id, key, var)
else:
var = generator_func()
return var
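# `PreprocessorCache` is defined in the OD API, not in this file. As an
# illustrative sketch only (a hypothetical helper, not part of the original
# module): a minimal dict-backed stand-in satisfying the get/update interface
# that `_get_or_create_preprocess_rand_vars` relies on could look like this.
class _MinimalPreprocessorCache:
  """Toy cache keyed by (function_id, key); for illustration only."""

  def __init__(self):
    self._cache = {}

  def get(self, function_id, key):
    # A miss returns None, which makes the caller regenerate the tensor.
    return self._cache.get((function_id, key))

  def update(self, function_id, key, value):
    self._cache[(function_id, key)] = value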
def _random_integer(minval, maxval, seed):
"""Returns a random 0-D tensor between minval and maxval.
Args:
minval: minimum value of the random tensor.
maxval: maximum value of the random tensor.
seed: random seed.
Returns:
A random 0-D tensor between minval and maxval.
"""
return tf.random.uniform(
[], minval=minval, maxval=maxval, dtype=tf.int32, seed=seed)
def _get_crop_border(border, size):
"""Get the border of cropping."""
border = tf.cast(border, tf.float32)
size = tf.cast(size, tf.float32)
i = tf.math.ceil(tf.math.log(2.0 * border / size) / tf.math.log(2.0))
divisor = tf.pow(2.0, i)
divisor = tf.clip_by_value(divisor, 1, border)
divisor = tf.cast(divisor, tf.int32)
return tf.cast(border, tf.int32) // divisor
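# Worked example of the computation above (comment only, values assumed):
# with border=128 and size=100, i = ceil(log2(2 * 128 / 100)) = ceil(1.36) = 2,
# divisor = clip(2**2, 1, 128) = 4, and the returned border is 128 // 4 = 32,
# which satisfies 32 < 100 / 2, the invariant described in the docstring of
# random_square_crop_by_scale below.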
def random_square_crop_by_scale(image,
boxes,
labels,
max_border=128,
scale_min=0.6,
scale_max=1.3,
num_scales=8,
seed=None,
preprocess_vars_cache=None):
"""Randomly crop a square in proportion to scale and image size.
Extract a square sized crop from an image whose side length is sampled by
randomly scaling the maximum spatial dimension of the image. If part of
the crop falls outside the image, it is filled with zeros.
The augmentation is borrowed from [1]
[1]: https://arxiv.org/abs/1904.07850
Args:
image: rank 3 float32 tensor containing 1 image ->
[height, width, channels].
boxes: rank 2 float32 tensor containing the bounding boxes -> [N, 4].
Boxes are in normalized form meaning their coordinates vary
between [0, 1]. Each row is in the form of [ymin, xmin, ymax, xmax].
Boxes on the crop boundary are clipped to the boundary and boxes
falling outside the crop are ignored.
labels: rank 1 int32 tensor containing the object classes.
max_border: The maximum size of the border. The border defines distance in
pixels to the image boundaries that will not be considered as a center of
a crop. To make sure that the border does not go over the center of the
image, we chose the border value by computing the minimum k, such that
(max_border / (2**k)) < image_dimension/2.
scale_min: float, the minimum value for scale.
scale_max: float, the maximum value for scale.
num_scales: int, the number of discrete scale values to sample between
[scale_min, scale_max]
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same rank as input image.
boxes: boxes which is the same rank as input boxes.
Boxes are in normalized form.
labels: new labels.
"""
img_shape = tf.shape(image)
height, width = img_shape[0], img_shape[1]
scales = tf.linspace(scale_min, scale_max, num_scales)
scale = _get_or_create_preprocess_rand_vars(
lambda: scales[_random_integer(0, num_scales, seed)],
'square_crop_scale',
preprocess_vars_cache, 'scale')
image_size = scale * tf.cast(tf.maximum(height, width), tf.float32)
image_size = tf.cast(image_size, tf.int32)
h_border = _get_crop_border(max_border, height)
w_border = _get_crop_border(max_border, width)
def y_function():
y = _random_integer(h_border,
tf.cast(height, tf.int32) - h_border + 1,
seed)
return y
def x_function():
x = _random_integer(w_border,
tf.cast(width, tf.int32) - w_border + 1,
seed)
return x
y_center = _get_or_create_preprocess_rand_vars(
y_function,
'square_crop_scale',
preprocess_vars_cache, 'y_center')
x_center = _get_or_create_preprocess_rand_vars(
x_function,
'square_crop_scale',
preprocess_vars_cache, 'x_center')
half_size = tf.cast(image_size / 2, tf.int32)
crop_ymin, crop_ymax = y_center - half_size, y_center + half_size
crop_xmin, crop_xmax = x_center - half_size, x_center + half_size
ymin = tf.maximum(crop_ymin, 0)
xmin = tf.maximum(crop_xmin, 0)
ymax = tf.minimum(crop_ymax, height - 1)
xmax = tf.minimum(crop_xmax, width - 1)
cropped_image = image[ymin:ymax, xmin:xmax]
offset_y = tf.maximum(0, ymin - crop_ymin)
offset_x = tf.maximum(0, xmin - crop_xmin)
oy_i = offset_y
ox_i = offset_x
output_image = tf.image.pad_to_bounding_box(
cropped_image, offset_height=oy_i, offset_width=ox_i,
target_height=image_size, target_width=image_size)
if ymin == 0:
# We might be padding the image.
box_ymin = -offset_y
else:
box_ymin = crop_ymin
if xmin == 0:
# We might be padding the image.
box_xmin = -offset_x
else:
box_xmin = crop_xmin
box_ymax = box_ymin + image_size
box_xmax = box_xmin + image_size
image_box = [box_ymin / height, box_xmin / width,
box_ymax / height, box_xmax / width]
boxlist = box_list.BoxList(boxes)
boxlist = box_list_ops.change_coordinate_frame(boxlist, image_box)
boxlist, indices = box_list_ops.prune_completely_outside_window(
boxlist, [0.0, 0.0, 1.0, 1.0])
boxlist = box_list_ops.clip_to_window(boxlist, [0.0, 0.0, 1.0, 1.0],
filter_nonoverlapping=False)
return_values = [output_image,
boxlist.get(),
tf.gather(labels, indices)]
return return_values
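# Usage sketch (illustrative values only, not part of the original file):
# crop a random square from a dummy 480x640 image carrying two normalized
# boxes; the function returns the crop, the surviving boxes, and their labels.
def _example_random_square_crop():
  image = tf.random.uniform([480, 640, 3], maxval=255.0, dtype=tf.float32)
  boxes = tf.constant([[0.1, 0.1, 0.5, 0.5],
                       [0.4, 0.3, 0.9, 0.8]], dtype=tf.float32)
  labels = tf.constant([1, 2], dtype=tf.int32)
  cropped_image, cropped_boxes, kept_labels = random_square_crop_by_scale(
      image, boxes, labels, max_border=128, scale_min=0.6, scale_max=1.3)
  return cropped_image, cropped_boxes, kept_labels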
def resize_to_range(image,
masks=None,
min_dimension=None,
max_dimension=None,
method=tf.image.ResizeMethod.BILINEAR,
pad_to_max_dimension=False,
per_channel_pad_value=(0, 0, 0)):
"""Resizes an image so its dimensions are within the provided value.
The output size can be described by two cases:
1. If the image can be rescaled so its minimum dimension is equal to the
provided value without the other dimension exceeding max_dimension,
then do so.
2. Otherwise, resize so the largest dimension is equal to max_dimension.
Args:
image: A 3D tensor of shape [height, width, channels]
masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks.
min_dimension: (optional) (scalar) desired size of the smaller image
dimension.
max_dimension: (optional) (scalar) maximum allowed size
of the larger image dimension.
method: (optional) interpolation method used in resizing. Defaults to
BILINEAR.
pad_to_max_dimension: Whether to resize the image and pad it with zeros
so the resulting image is of the spatial size
[max_dimension, max_dimension]. If masks are included they are padded
similarly.
per_channel_pad_value: A tuple of per-channel scalar value to use for
padding. By default pads zeros.
Returns:
Note that the position of the resized_image_shape changes based on whether
masks are present.
resized_image: A 3D tensor of shape [new_height, new_width, channels],
where the image has been resized (with bilinear interpolation) so that
min(new_height, new_width) == min_dimension or
max(new_height, new_width) == max_dimension.
resized_masks: If masks is not None, also outputs masks. A 3D tensor of
shape [num_instances, new_height, new_width].
resized_image_shape: A 1D tensor of shape [3] containing shape of the
resized image.
Raises:
ValueError: if the image is not a 3D tensor.
"""
if len(image.get_shape()) != 3:
raise ValueError('Image should be 3D tensor')
def _resize_landscape_image(image):
# resize a landscape image
return tf.image.resize(
image, tf.stack([min_dimension, max_dimension]), method=method,
preserve_aspect_ratio=True)
def _resize_portrait_image(image):
# resize a portrait image
return tf.image.resize(
image, tf.stack([max_dimension, min_dimension]), method=method,
preserve_aspect_ratio=True)
with tf.name_scope('ResizeToRange'):
if image.get_shape().is_fully_defined():
if image.get_shape()[0] < image.get_shape()[1]:
new_image = _resize_landscape_image(image)
else:
new_image = _resize_portrait_image(image)
new_size = tf.constant(new_image.get_shape().as_list())
else:
new_image = tf.cond(
tf.less(tf.shape(image)[0], tf.shape(image)[1]),
lambda: _resize_landscape_image(image),
lambda: _resize_portrait_image(image))
new_size = tf.shape(new_image)
if pad_to_max_dimension:
channels = tf.unstack(new_image, axis=2)
if len(channels) != len(per_channel_pad_value):
raise ValueError('Number of channels must be equal to the length of '
'per-channel pad value.')
new_image = tf.stack(
[
tf.pad( # pylint: disable=g-complex-comprehension
channels[i], [[0, max_dimension - new_size[0]],
[0, max_dimension - new_size[1]]],
constant_values=per_channel_pad_value[i])
for i in range(len(channels))
],
axis=2)
new_image.set_shape([max_dimension, max_dimension, len(channels)])
result = [new_image, new_size]
if masks is not None:
new_masks = tf.expand_dims(masks, 3)
new_masks = tf.image.resize(
new_masks,
new_size[:-1],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
if pad_to_max_dimension:
new_masks = tf.image.pad_to_bounding_box(
new_masks, 0, 0, max_dimension, max_dimension)
new_masks = tf.squeeze(new_masks, 3)
result.append(new_masks)
return result
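# Usage sketch (hypothetical sizes): resize a 600x400 image so its shorter
# side becomes 512 without the longer side exceeding 1024, then pad to a
# square 1024x1024 canvas.
def _example_resize_to_range():
  image = tf.random.uniform([600, 400, 3], maxval=255.0, dtype=tf.float32)
  resized_image, resized_shape = resize_to_range(
      image, min_dimension=512, max_dimension=1024, pad_to_max_dimension=True)
  # resized_image is 1024x1024x3 after padding; resized_shape still holds the
  # pre-padding size ([768, 512, 3] for this input).
  return resized_image, resized_shape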
def _augment_only_rgb_channels(image, augment_function):
"""Augments only the RGB slice of an image with additional channels."""
rgb_slice = image[:, :, :3]
augmented_rgb_slice = augment_function(rgb_slice)
image = tf.concat([augmented_rgb_slice, image[:, :, 3:]], -1)
return image
def random_adjust_brightness(image,
max_delta=0.2,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts brightness.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
max_delta: how much to change the brightness. A value between [0, 1).
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustBrightness'):
generator_func = functools.partial(tf.random.uniform, [],
-max_delta, max_delta, seed=seed)
delta = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_brightness',
preprocess_vars_cache)
def _adjust_brightness(image):
image = tf.image.adjust_brightness(image / 255, delta) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_brightness)
return image
def random_adjust_contrast(image,
min_delta=0.8,
max_delta=1.25,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts contrast.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
min_delta: see max_delta.
max_delta: how much to change the contrast. Contrast will change with a
value between min_delta and max_delta. This value will be
multiplied to the current contrast of the image.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustContrast'):
generator_func = functools.partial(tf.random.uniform, [],
min_delta, max_delta, seed=seed)
contrast_factor = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_contrast',
preprocess_vars_cache)
def _adjust_contrast(image):
image = tf.image.adjust_contrast(image / 255, contrast_factor) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_contrast)
return image
def random_adjust_hue(image,
max_delta=0.02,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts hue.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
max_delta: change hue randomly with a value between 0 and max_delta.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustHue'):
generator_func = functools.partial(tf.random.uniform, [],
-max_delta, max_delta, seed=seed)
delta = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_hue',
preprocess_vars_cache)
def _adjust_hue(image):
image = tf.image.adjust_hue(image / 255, delta) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_hue)
return image
def random_adjust_saturation(image,
min_delta=0.8,
max_delta=1.25,
seed=None,
preprocess_vars_cache=None):
"""Randomly adjusts saturation.
Makes sure the output image is still between 0 and 255.
Args:
image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
with pixel values varying between [0, 255].
min_delta: see max_delta.
max_delta: how much to change the saturation. Saturation will change with a
value between min_delta and max_delta. This value will be
multiplied to the current saturation of the image.
seed: random seed.
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
function is called multiple times with the same
non-null cache, it will perform deterministically.
Returns:
image: image which is the same shape as input image.
"""
with tf.name_scope('RandomAdjustSaturation'):
generator_func = functools.partial(tf.random.uniform, [],
min_delta, max_delta, seed=seed)
saturation_factor = _get_or_create_preprocess_rand_vars(
generator_func,
'adjust_saturation',
preprocess_vars_cache)
def _adjust_saturation(image):
image = tf.image.adjust_saturation(image / 255, saturation_factor) * 255
image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
return image
image = _augment_only_rgb_channels(image, _adjust_saturation)
return image
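# Determinism sketch: each photometric op above accepts a
# preprocess_vars_cache, so replaying a pipeline with one shared cache
# reproduces the same random draws. This example reuses the illustrative
# _MinimalPreprocessorCache sketched near the top of this file (the real
# PreprocessorCache lives in the OD API).
def _example_deterministic_color_ops(image):
  cache = _MinimalPreprocessorCache()
  out = random_adjust_brightness(image, preprocess_vars_cache=cache)
  out = random_adjust_contrast(out, preprocess_vars_cache=cache)
  # A second pass with the same cache draws the identical delta and factor.
  replay = random_adjust_brightness(image, preprocess_vars_cache=cache)
  replay = random_adjust_contrast(replay, preprocess_vars_cache=cache)
  return out, replay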
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ from typing import Dict, List
 import tensorflow as tf
-from official.vision.beta.ops import sampling_ops
+from official.vision.ops import sampling_ops
 def smallest_positive_root(a, b, c):
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,8 @@
 from absl.testing import parameterized
 import tensorflow as tf
-from official.vision.beta.ops import preprocess_ops
-from official.vision.beta.projects.centernet.ops import target_assigner
+from official.projects.centernet.ops import target_assigner
+from official.vision.ops import preprocess_ops
 class TargetAssignerTest(tf.test.TestCase, parameterized.TestCase):
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Centernet task definition."""
from typing import Any, List, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.projects.centernet.configs import centernet as exp_cfg
from official.projects.centernet.dataloaders import centernet_input
from official.projects.centernet.losses import centernet_losses
from official.projects.centernet.modeling import centernet_model
from official.projects.centernet.modeling.heads import centernet_head
from official.projects.centernet.modeling.layers import detection_generator
from official.projects.centernet.ops import loss_ops
from official.projects.centernet.ops import target_assigner
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.modeling.backbones import factory
@task_factory.register_task_cls(exp_cfg.CenterNetTask)
class CenterNetTask(base_task.Task):
"""Task definition for centernet."""
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
if params.tfds_name:
decoder = tfds_factory.get_detection_decoder(params.tfds_name)
else:
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
parser = centernet_input.CenterNetParser(
output_height=self.task_config.model.input_size[0],
output_width=self.task_config.model.input_size[1],
max_num_instances=self.task_config.model.max_num_instances,
bgr_ordering=params.parser.bgr_ordering,
channel_means=params.parser.channel_means,
channel_stds=params.parser.channel_stds,
aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
aug_rand_hue=params.parser.aug_rand_hue,
aug_rand_brightness=params.parser.aug_rand_brightness,
aug_rand_contrast=params.parser.aug_rand_contrast,
aug_rand_saturation=params.parser.aug_rand_saturation,
odapi_augmentation=params.parser.odapi_augmentation,
dtype=params.dtype)
reader = input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_model(self):
"""get an instance of CenterNet."""
model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec(
shape=[None] + model_config.input_size)
l2_weight_decay = self.task_config.weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
backbone = factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=model_config.norm_activation,
l2_regularizer=l2_regularizer)
task_outputs = self.task_config.get_output_length_dict()
head_config = model_config.head
head = centernet_head.CenterNetHead(
input_specs=backbone.output_specs,
task_outputs=task_outputs,
input_levels=head_config.input_levels,
heatmap_bias=head_config.heatmap_bias)
# output_specs is a dict
backbone_output_spec = backbone.output_specs[head_config.input_levels[-1]]
if len(backbone_output_spec) == 4:
bb_output_height = backbone_output_spec[1]
elif len(backbone_output_spec) == 3:
bb_output_height = backbone_output_spec[0]
else:
raise ValueError('Unexpected rank of the backbone output spec.')
self._net_down_scale = int(model_config.input_size[0] / bb_output_height)
dg_config = model_config.detection_generator
detect_generator_obj = detection_generator.CenterNetDetectionGenerator(
max_detections=dg_config.max_detections,
peak_error=dg_config.peak_error,
peak_extract_kernel_size=dg_config.peak_extract_kernel_size,
class_offset=dg_config.class_offset,
net_down_scale=self._net_down_scale,
input_image_dims=model_config.input_size[0],
use_nms=dg_config.use_nms,
nms_pre_thresh=dg_config.nms_pre_thresh,
nms_thresh=dg_config.nms_thresh)
model = centernet_model.CenterNetModel(
backbone=backbone,
head=head,
detection_generator=detect_generator_obj)
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
# Restoring checkpoint.
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
raise ValueError(
"Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_losses(self,
outputs,
labels,
aux_losses=None):
"""Build losses."""
input_size = self.task_config.model.input_size[0:2]
output_size = outputs['ct_heatmaps'][0].get_shape().as_list()[1:3]
gt_label = tf.map_fn(
# pylint: disable=g-long-lambda
fn=lambda x: target_assigner.assign_centernet_targets(
labels=x,
input_size=input_size,
output_size=output_size,
num_classes=self.task_config.model.num_classes,
max_num_instances=self.task_config.model.max_num_instances,
gaussian_iou=self.task_config.losses.gaussian_iou,
class_offset=self.task_config.losses.class_offset),
elems=labels,
fn_output_signature={
'ct_heatmaps': tf.TensorSpec(
shape=[output_size[0], output_size[1],
self.task_config.model.num_classes],
dtype=tf.float32),
'ct_offset': tf.TensorSpec(
shape=[self.task_config.model.max_num_instances, 2],
dtype=tf.float32),
'size': tf.TensorSpec(
shape=[self.task_config.model.max_num_instances, 2],
dtype=tf.float32),
'box_mask': tf.TensorSpec(
shape=[self.task_config.model.max_num_instances],
dtype=tf.int32),
'box_indices': tf.TensorSpec(
shape=[self.task_config.model.max_num_instances, 2],
dtype=tf.int32),
}
)
losses = {}
# Create loss functions
object_center_loss_fn = centernet_losses.PenaltyReducedLogisticFocalLoss()
localization_loss_fn = centernet_losses.L1LocalizationLoss()
# Set up box indices so that they have a batch element as well
box_indices = loss_ops.add_batch_to_indices(gt_label['box_indices'])
box_mask = tf.cast(gt_label['box_mask'], dtype=tf.float32)
num_boxes = tf.cast(
loss_ops.get_num_instances_from_weights(gt_label['box_mask']),
dtype=tf.float32)
# Calculate center heatmap loss
output_unpad_image_shapes = tf.math.ceil(
tf.cast(labels['unpad_image_shapes'],
tf.float32) / self._net_down_scale)
valid_anchor_weights = loss_ops.get_valid_anchor_weights_in_flattened_image(
output_unpad_image_shapes, output_size[0], output_size[1])
valid_anchor_weights = tf.expand_dims(valid_anchor_weights, 2)
pred_ct_heatmap_list = outputs['ct_heatmaps']
true_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
gt_label['ct_heatmaps'])
true_flattened_ct_heatmap = tf.cast(true_flattened_ct_heatmap, tf.float32)
total_center_loss = 0.0
for ct_heatmap in pred_ct_heatmap_list:
pred_flattened_ct_heatmap = loss_ops.flatten_spatial_dimensions(
ct_heatmap)
pred_flattened_ct_heatmap = tf.cast(pred_flattened_ct_heatmap, tf.float32)
total_center_loss += object_center_loss_fn(
target_tensor=true_flattened_ct_heatmap,
prediction_tensor=pred_flattened_ct_heatmap,
weights=valid_anchor_weights)
center_loss = tf.reduce_sum(total_center_loss) / float(
len(pred_ct_heatmap_list) * num_boxes)
losses['ct_loss'] = center_loss
# Calculate scale loss
pred_scale_list = outputs['ct_size']
true_scale = tf.cast(gt_label['size'], tf.float32)
total_scale_loss = 0.0
for scale_map in pred_scale_list:
pred_scale = loss_ops.get_batch_predictions_from_indices(scale_map,
box_indices)
pred_scale = tf.cast(pred_scale, tf.float32)
# Only apply loss for boxes that appear in the ground truth
total_scale_loss += tf.reduce_sum(
localization_loss_fn(target_tensor=true_scale,
prediction_tensor=pred_scale),
axis=-1) * box_mask
scale_loss = tf.reduce_sum(total_scale_loss) / float(
len(pred_scale_list) * num_boxes)
losses['scale_loss'] = scale_loss
# Calculate offset loss
pred_offset_list = outputs['ct_offset']
true_offset = tf.cast(gt_label['ct_offset'], tf.float32)
total_offset_loss = 0.0
for offset_map in pred_offset_list:
pred_offset = loss_ops.get_batch_predictions_from_indices(offset_map,
box_indices)
pred_offset = tf.cast(pred_offset, tf.float32)
# Only apply loss for boxes that appear in the ground truth
total_offset_loss += tf.reduce_sum(
localization_loss_fn(target_tensor=true_offset,
prediction_tensor=pred_offset),
axis=-1) * box_mask
offset_loss = tf.reduce_sum(total_offset_loss) / float(
len(pred_offset_list) * num_boxes)
losses['ct_offset_loss'] = offset_loss
# Aggregate and finalize loss
loss_weights = self.task_config.losses.detection
total_loss = (loss_weights.object_center_weight * center_loss +
loss_weights.scale_weight * scale_loss +
loss_weights.offset_weight * offset_loss)
if aux_losses:
total_loss += tf.add_n(aux_losses)
losses['total_loss'] = total_loss
return losses
def build_metrics(self, training=True):
metrics = []
metric_names = ['total_loss', 'ct_loss', 'scale_loss', 'ct_offset_loss']
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
if not training:
if (self.task_config.validation_data.tfds_name
and self.task_config.annotation_file):
raise ValueError(
"Can't evaluate using annotation file when TFDS is used.")
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self.task_config.annotation_file,
include_mask=False,
per_category_metrics=self.task_config.per_category_metrics)
return metrics
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
# Casting the model outputs to float32 is necessary when the mixed precision
# policy is mixed_float16 or mixed_bfloat16, so that the loss is computed in
# float32.
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
losses = self.build_losses(outputs['raw_output'], labels)
scaled_loss = losses['total_loss'] / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
# Compute the gradients.
tvars = model.trainable_variables
gradients = tape.gradient(scaled_loss, tvars)
# Unscale the gradients if the loss was scaled.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
gradients = optimizer.get_unscaled_gradients(gradients)
if self.task_config.gradient_clip_norm > 0.0:
gradients, _ = tf.clip_by_global_norm(gradients,
self.task_config.gradient_clip_norm)
optimizer.apply_gradients(list(zip(gradients, tvars)))
logs = {self.loss: losses['total_loss']}
if metrics:
for m in metrics:
m.update_state(losses[m.name])
logs.update({m.name: m.result()})
return logs
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
outputs = model(features, training=False)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
losses = self.build_losses(outputs['raw_output'], labels)
logs = {self.loss: losses['total_loss']}
coco_model_outputs = {
'detection_boxes': outputs['boxes'],
'detection_scores': outputs['confidence'],
'detection_classes': outputs['classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
logs.update({self.coco_metric.name: (labels['groundtruths'],
coco_model_outputs)})
if metrics:
for m in metrics:
m.update_state(losses[m.name])
logs.update({m.name: m.result()})
return logs
def aggregate_logs(self, state=None, step_outputs=None):
if state is None:
self.coco_metric.reset_states()
state = self.coco_metric
self.coco_metric.update_state(step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
return self.coco_metric.result()
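# Usage sketch (hypothetical; assumes the default exp_cfg.CenterNetTask()
# values form a valid configuration): build the task, model, and metrics
# outside of a full training loop.
def _example_build_centernet_task():
  config = exp_cfg.CenterNetTask()
  task = task_factory.get_task(config)
  model = task.build_model()
  metrics = task.build_metrics(training=True)
  return task, model, metrics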
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision Centernet trainer."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.projects.centernet.common import registry_imports # pylint: disable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
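# Launch sketch (comment only; the experiment name and paths are assumptions,
# while the flags themselves are defined by tfm_flags.define_flags() above):
#   python3 -m official.projects.centernet.train \
#     --mode=train_and_eval \
#     --experiment=centernet_hourglass_coco \
#     --model_dir=/tmp/centernet \
#     --config_file=path/to/override.yaml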
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ from typing import Dict, Optional
 import numpy as np
-from official.vision.beta.projects.centernet.utils.checkpoints import config_classes
+from official.projects.centernet.utils.checkpoints import config_classes
 Conv2DBNCFG = config_classes.Conv2DBNCFG
 HeadConvCFG = config_classes.HeadConvCFG
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,11 +14,11 @@
 """Functions used to load the ODAPI CenterNet checkpoint."""
-from official.vision.beta.modeling.backbones import mobilenet
-from official.vision.beta.modeling.layers import nn_blocks
-from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks
-from official.vision.beta.projects.centernet.utils.checkpoints import config_classes
-from official.vision.beta.projects.centernet.utils.checkpoints import config_data
+from official.projects.centernet.modeling.layers import cn_nn_blocks
+from official.projects.centernet.utils.checkpoints import config_classes
+from official.projects.centernet.utils.checkpoints import config_data
+from official.vision.modeling.backbones import mobilenet
+from official.vision.modeling.layers import nn_blocks
 Conv2DBNCFG = config_classes.Conv2DBNCFG
 HeadConvCFG = config_classes.HeadConvCFG
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,15 +19,15 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
-from official.vision.beta.modeling.backbones import factory
-from official.vision.beta.projects.centernet.common import registry_imports  # pylint: disable=unused-import
-from official.vision.beta.projects.centernet.configs import backbones
-from official.vision.beta.projects.centernet.configs import centernet
-from official.vision.beta.projects.centernet.modeling import centernet_model
-from official.vision.beta.projects.centernet.modeling.heads import centernet_head
-from official.vision.beta.projects.centernet.modeling.layers import detection_generator
-from official.vision.beta.projects.centernet.utils.checkpoints import load_weights
-from official.vision.beta.projects.centernet.utils.checkpoints import read_checkpoints
+from official.projects.centernet.common import registry_imports  # pylint: disable=unused-import
+from official.projects.centernet.configs import backbones
+from official.projects.centernet.configs import centernet
+from official.projects.centernet.modeling import centernet_model
+from official.projects.centernet.modeling.heads import centernet_head
+from official.projects.centernet.modeling.layers import detection_generator
+from official.projects.centernet.utils.checkpoints import load_weights
+from official.projects.centernet.utils.checkpoints import read_checkpoints
+from official.vision.modeling.backbones import factory
 FLAGS = flags.FLAGS
......
# Contextualized Spatial-Temporal Contrastive Learning with Self-Supervision
(WIP) This repository contains the official implementation of
[Contextualized Spatio-Temporal Contrastive Learning with Self-Supervision](https://arxiv.org/abs/2112.05181)
in TF2.
# Crown-of-Thorns Starfish Detection Pipeline
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb?force_crab_mode=1)
This repository shows how to detect crown-of-thorns starfish (COTS) using a
pre-trained COTS detector implemented in TensorFlow.
![Underwater photo of coral reef with annotated boxes identifying detected
starfish](https://storage.googleapis.com/download.tensorflow.org/data/cots_detection/COTS_detected_sample.png)
## Description
Coral reefs are some of the most diverse and important ecosystems in the
world; however, they face a number of rising threats that have resulted in massive
global declines. In Australia, outbreaks of the coral-eating crown-of-thorns
starfish (COTS) have been shown to cause major coral loss, with just 15 starfish
in a hectare being able to strip a reef of 90% of its coral tissue. While COTS
naturally exist in the Indo-Pacific, overfishing and excess run-off nutrients
have led to massive outbreaks that are devastating already vulnerable coral
communities.
Controlling COTS populations is critical to promoting coral growth and
resilience, so Google teamed up with Australia’s national science agency,
[CSIRO](https://www.csiro.au/en/), to tackle this problem. We trained ML object
detection models to help scale underwater surveys, enabling the monitoring
and mapping of these harmful invertebrates, with the ultimate goal of helping
control teams address and prioritize outbreaks.
## Get started
[Open the notebook in Colab](https://colab.research.google.com/github/tensorflow/models/blob/master/official/projects/cots_detector/crown_of_thorns_starfish_detection_pipeline.ipynb?force_crab_mode=1)
to run the COTS detection pipeline.