Commit 2e9bb539 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 7bae5317 8fba84f8
# ResNet-350 ImageNet classification. 84.2% top-1 accuracy.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
@@ -9,14 +8,17 @@ task:
backbone:
type: 'resnet'
resnet:
model_id: 350
depth_multiplier: 1.25
stem_type: 'v1'
model_id: 420
replace_stem_max_pool: true
resnetd_shortcut: true
se_ratio: 0.25
stochastic_depth_drop_rate: 0.2
stem_type: 'v1'
stochastic_depth_drop_rate: 0.1
norm_activation:
activation: 'swish'
dropout_rate: 0.5
norm_momentum: 0.0
use_sync_bn: false
dropout_rate: 0.4
losses:
l2_weight_decay: 0.00004
one_hot: true
@@ -27,6 +29,7 @@ task:
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
@@ -41,6 +44,8 @@ trainer:
summary_interval: 312
checkpoint_interval: 312
optimizer_config:
ema:
average_decay: 0.9999
optimizer:
type: 'sgd'
sgd:
@@ -53,4 +58,4 @@ trainer:
warmup:
type: 'linear'
linear:
warmup_steps: 5000
warmup_steps: 1560
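For context, the revised step counts fall out of the ImageNet-1k epoch size at a 4096 global batch; a quick arithmetic sketch (assuming the standard 1,281,167-example train split):

steps_per_epoch = 1281167 // 4096    # 312, matching the 312-step intervals above
warmup_steps = 5 * steps_per_epoch   # 1560, as in the config
train_steps = 350 * steps_per_epoch  # 109200, i.e. 350 epochs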
# ResNet-RS-50 ImageNet classification. 79.1% top-1 accuracy.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
num_classes: 1001
input_size: [160, 160, 3]
backbone:
type: 'resnet'
resnet:
model_id: 50
replace_stem_max_pool: true
resnetd_shortcut: true
se_ratio: 0.25
stem_type: 'v1'
stochastic_depth_drop_rate: 0.0
norm_activation:
activation: 'swish'
norm_momentum: 0.0
use_sync_bn: false
dropout_rate: 0.25
losses:
l2_weight_decay: 0.00004
one_hot: true
label_smoothing: 0.1
train_data:
input_path: 'imagenet-2012-tfrecord/train*'
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 10
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
global_batch_size: 4096
dtype: 'bfloat16'
drop_remainder: false
trainer:
train_steps: 109200
validation_steps: 13
validation_interval: 312
steps_per_loop: 312
summary_interval: 312
checkpoint_interval: 312
optimizer_config:
ema:
average_decay: 0.9999
optimizer:
type: 'sgd'
sgd:
momentum: 0.9
learning_rate:
type: 'cosine'
cosine:
initial_learning_rate: 1.6
decay_steps: 109200
warmup:
type: 'linear'
linear:
warmup_steps: 1560
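This file mirrors the `resnet_rs_imagenet` experiment registered below; a minimal sketch of merging such a YAML over the registered defaults, assuming the usual Model Garden helpers (the file path is a placeholder):

from official.core import exp_factory
from official.modeling import hyperparams

config = exp_factory.get_exp_config('resnet_rs_imagenet')
# Placeholder path; override_params_dict merges the YAML over the defaults.
config = hyperparams.override_params_dict(
    config, 'path/to/imagenet_resnetrs50_i160.yaml', is_strict=True)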
@@ -35,6 +35,7 @@ class DataConfig(cfg.DataConfig):
shuffle_buffer_size: int = 10000
cycle_length: int = 10
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'
randaug_magnitude: Optional[int] = 10
file_type: str = 'tfrecord'
@@ -162,6 +163,82 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
return config
@exp_factory.register_config_factory('resnet_rs_imagenet')
def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
"""Image classification on imagenet with resnet-rs."""
train_batch_size = 4096
eval_batch_size = 4096
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[160, 160, 3],
backbone=backbones.Backbone(
type='resnet',
resnet=backbones.ResNet(
model_id=50,
stem_type='v1',
resnetd_shortcut=True,
replace_stem_max_pool=True,
se_ratio=0.25,
stochastic_depth_drop_rate=0.0)),
dropout_rate=0.25,
norm_activation=common.NormActivation(
norm_momentum=0.0,
norm_epsilon=1e-5,
use_sync_bn=False,
activation='swish')),
losses=Losses(l2_weight_decay=4e-5, label_smoothing=0.1),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
aug_policy='randaug',
randaug_magnitude=10),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=350 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9
}
},
'ema': {
'average_decay': 0.9999
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 1.6,
'decay_steps': 350 * steps_per_epoch
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
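A short usage sketch for the factory just registered; `override` validates keys against the dataclasses above (the override values are illustrative):

from official.core import exp_factory

config = exp_factory.get_exp_config('resnet_rs_imagenet')
# Tweak a few fields programmatically; unknown keys raise a KeyError.
config.override({
    'task': {'train_data': {'global_batch_size': 1024}},
    'trainer': {'train_steps': 1000},
})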
@exp_factory.register_config_factory('revnet_imagenet')
def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
"""Returns a revnet config for image classification on imagenet."""
......
@@ -26,9 +26,12 @@ from official.vision.beta.configs import image_classification as exp_cfg
class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('resnet_imagenet',),
('revnet_imagenet',),
('mobilenet_imagenet'),)
@parameterized.parameters(
('resnet_imagenet',),
('resnet_rs_imagenet',),
('revnet_imagenet',),
('mobilenet_imagenet'),
)
def test_image_classification_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
......
@@ -90,6 +90,12 @@ class Losses(hyperparams.Config):
top_k_percent_pixels: float = 1.0
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
report_per_class_iou: bool = True
report_train_mean_iou: bool = True # Turning this off can speed up training.
@dataclasses.dataclass
class SemanticSegmentationTask(cfg.TaskConfig):
"""The model config."""
@@ -97,6 +103,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses()
evaluation: Evaluation = Evaluation()
train_input_partition_dims: List[int] = dataclasses.field(
default_factory=list)
eval_input_partition_dims: List[int] = dataclasses.field(
......
@@ -49,6 +49,7 @@ class Parser(parser.Parser):
num_classes: float,
aug_rand_hflip: bool = True,
aug_policy: Optional[str] = None,
randaug_magnitude: Optional[int] = 10,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
@@ -59,6 +60,7 @@
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
randaug_magnitude: `int`, magnitude of the randaugment policy.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
"""
@@ -77,7 +79,8 @@
if aug_policy == 'autoaug':
self._augmenter = augment.AutoAugment()
elif aug_policy == 'randaug':
self._augmenter = augment.RandAugment(num_layers=2, magnitude=20)
self._augmenter = augment.RandAugment(
num_layers=2, magnitude=randaug_magnitude)
else:
raise ValueError(
'Augmentation policy {} not supported.'.format(aug_policy))
......
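A standalone sketch of what the parser now builds when `aug_policy='randaug'`; the magnitude flows in from the config (e.g. `randaug_magnitude: 15` in the YAML above), and the toy image is illustrative:

import tensorflow as tf
from official.vision.beta.ops import augment

augmenter = augment.RandAugment(num_layers=2, magnitude=15)
# RandAugment operates on uint8 images; distort applies num_layers random ops.
image = tf.random.uniform([224, 224, 3], maxval=256, dtype=tf.int32)
augmented = augmenter.distort(tf.cast(image, tf.uint8))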
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset reader for vision model garden."""
from typing import Any, Callable, Optional, Tuple
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
def calculate_batch_sizes(total_batch_size: int,
pseudo_label_ratio: float) -> Tuple[int, int]:
"""Calculates labeled and pseudo-labeled dataset batch sizes.
Returns (labeled_batch_size, pseudo_labeled_batch_size) given a
total batch size and pseudo-label data ratio.
Args:
total_batch_size: The total batch size for all data.
pseudo_label_ratio: A non-negative float ratio of pseudo-labeled
to labeled data in a batch.
Returns:
(labeled_batch_size, pseudo_labeled_batch_size) as ints.
Raises:
ValueError: If total_batch_size is negative.
ValueError: If pseudo_label_ratio is negative.
"""
if total_batch_size < 0:
raise ValueError('Invalid total_batch_size: {}'.format(total_batch_size))
if pseudo_label_ratio < 0.0:
raise ValueError(
'Invalid pseudo_label_ratio: {}'.format(pseudo_label_ratio))
ratio_factor = pseudo_label_ratio / (1.0 + pseudo_label_ratio)
pseudo_labeled_batch_size = int(round(total_batch_size * ratio_factor))
labeled_batch_size = total_batch_size - pseudo_labeled_batch_size
return labeled_batch_size, pseudo_labeled_batch_size
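A worked example of the split; the numbers can be checked by hand:

# ratio_factor = 0.5 / 1.5 = 1/3, so the pseudo-label share of a 4096 batch
# is round(4096 / 3) = 1365 and the labeled share is the remaining 2731.
labeled, pseudo = calculate_batch_sizes(4096, 0.5)
assert (labeled, pseudo) == (2731, 1365)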
class CombinationDatasetInputReader(input_reader.InputReader):
"""Combination dataset input reader."""
def __init__(self,
params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset,
pseudo_label_dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
sample_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
transform_and_batch_fn: Optional[Callable[
[tf.data.Dataset, Optional[tf.distribute.InputContext]],
tf.data.Dataset]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None):
"""Initializes an CombinationDatasetInputReader instance.
This class mixes a labeled and pseudo-labeled dataset. The params
must contain "pseudo_label_data.input_path" to specify the
pseudo-label dataset files and "pseudo_label_data.data_ratio"
to specify a per-batch mixing ratio of pseudo-label examples to
labeled dataset examples.
Args:
params: A config_definitions.DataConfig object.
dataset_fn: A `tf.data.Dataset` that consumes the input files. For
example, it can be `tf.data.TFRecordDataset`.
pseudo_label_dataset_fn: A `tf.data.Dataset` that consumes the input
files. For example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
and decodes them into the raw tensor dictionary.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parses them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
transform_and_batch_fn: An optional `callable` that takes a
`tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
input, and returns a `tf.data.Dataset` object. It will be executed after
`parser_fn` to transform and batch the dataset; if None, after
`parser_fn` is executed, the dataset will be batched into per-replica
batch size.
postprocess_fn: An optional `callable` that processes batched tensors. It
will be executed after batching.
Raises:
ValueError: If drop_remainder is False.
"""
super().__init__(params=params,
dataset_fn=dataset_fn,
decoder_fn=decoder_fn,
sample_fn=sample_fn,
parser_fn=parser_fn,
transform_and_batch_fn=transform_and_batch_fn,
postprocess_fn=postprocess_fn)
self._pseudo_label_file_pattern = params.pseudo_label_data.input_path
self._pseudo_label_dataset_fn = pseudo_label_dataset_fn
self._pseudo_label_data_ratio = params.pseudo_label_data.data_ratio
self._pseudo_label_matched_files = self._match_files(
self._pseudo_label_file_pattern)
if not self._drop_remainder:
raise ValueError(
'Must use drop_remainder=True with CombinationDatasetInputReader')
def read(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
labeled_batch_size, pl_batch_size = calculate_batch_sizes(
self._global_batch_size, self._pseudo_label_data_ratio)
if not labeled_batch_size and pl_batch_size:
raise ValueError(
'Invalid batch_size: {} and pseudo_label_data_ratio: {}, '
'resulting in a 0 batch size for one of the datasets.'.format(
self._global_batch_size, self._pseudo_label_data_ratio))
labeled_dataset = self._read_decode_and_parse_dataset(
matched_files=self._matched_files,
dataset_fn=self._dataset_fn,
batch_size=labeled_batch_size,
input_context=input_context,
tfds_builder=self._tfds_builder)
pseudo_labeled_dataset = self._read_decode_and_parse_dataset(
matched_files=self._pseudo_label_matched_files,
dataset_fn=self._pseudo_label_dataset_fn,
batch_size=pl_batch_size,
input_context=input_context,
tfds_builder=False)
def concat_fn(d1, d2):
return tf.nest.map_structure(
lambda x1, x2: tf.concat([x1, x2], axis=0), d1, d2)
dataset_concat = tf.data.Dataset.zip(
(labeled_dataset, pseudo_labeled_dataset))
dataset_concat = dataset_concat.map(
concat_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def maybe_map_fn(dataset, fn):
return dataset if fn is None else dataset.map(
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset_concat = maybe_map_fn(dataset_concat, self._postprocess_fn)
dataset_concat = self._maybe_apply_data_service(dataset_concat,
input_context)
if self._deterministic is not None:
options = tf.data.Options()
options.experimental_deterministic = self._deterministic
dataset_concat = dataset_concat.with_options(options)
return dataset_concat.prefetch(tf.data.experimental.AUTOTUNE)
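The zip-and-concat mixing above can be seen with two toy datasets (a self-contained sketch, not the reader itself):

import tensorflow as tf

labeled = tf.data.Dataset.range(6).batch(3, drop_remainder=True)
pseudo = tf.data.Dataset.range(100, 104).batch(1, drop_remainder=True)
# Each mixed batch concatenates one labeled batch and one pseudo-labeled batch.
mixed = tf.data.Dataset.zip((labeled, pseudo)).map(
    lambda a, b: tf.concat([a, b], axis=0))
for batch in mixed.take(1):
  print(batch.numpy())  # [0 1 2 100]: global batch = 3 labeled + 1 pseudo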
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Factory for getting TF-Vision input readers."""
from official.common import dataset_fn as dataset_fn_util
from official.core import config_definitions as cfg
from official.core import input_reader as core_input_reader
from official.vision.beta.dataloaders import input_reader as vision_input_reader
def input_reader_generator(params: cfg.DataConfig,
**kwargs) -> core_input_reader.InputReader:
"""Instantiates an input reader class according to the params.
Args:
params: A config_definitions.DataConfig object.
**kwargs: Additional arguments passed to input reader initialization.
Returns:
An InputReader object.
"""
if params.is_training and params.get('pseudo_label_data', False):
return vision_input_reader.CombinationDatasetInputReader(
params,
pseudo_label_dataset_fn=dataset_fn_util.pick_dataset_fn(
params.pseudo_label_data.file_type),
**kwargs)
else:
return core_input_reader.InputReader(params, **kwargs)
@@ -16,6 +16,8 @@
import tensorflow as tf
from official.vision import keras_cv
class MeanIoU(tf.keras.metrics.MeanIoU):
"""Mean IoU metric for semantic segmentation.
@@ -122,3 +124,110 @@ class MeanIoU(tf.keras.metrics.MeanIoU):
super(MeanIoU, self).update_state(
flatten_masks, flatten_predictions,
tf.cast(flatten_valid_masks, tf.float32))
class PerClassIoU(keras_cv.metrics.PerClassIoU):
"""Per Class IoU metric for semantic segmentation.
  This class uses keras_cv.metrics.PerClassIoU to compute batched per-class
  IoU when input images and groundtruth masks share the same size
  (rescale_predictions=False). It can also compute per-class IoU at the
  original groundtruth sizes, in which case each prediction is rescaled back
  to the original image size.
  """
def __init__(
self, num_classes, rescale_predictions=False, name=None, dtype=None):
"""Constructs Segmentation evaluator class.
Args:
num_classes: `int`, number of classes.
rescale_predictions: `bool`, whether to scale back prediction to original
image sizes. If True, y_true['image_info'] is used to rescale
predictions.
name: `str`, name of the metric instance.
dtype: data type of the metric result.
"""
self._rescale_predictions = rescale_predictions
super(PerClassIoU, self).__init__(
num_classes=num_classes, name=name, dtype=dtype)
def update_state(self, y_true, y_pred):
"""Updates metric state.
Args:
y_true: `dict`, dictionary with the following name, and key values.
- masks: [batch, width, height, 1], groundtruth masks.
- valid_masks: [batch, width, height, 1], valid elements in the mask.
- image_info: [batch, 4, 2], a tensor that holds information about
original and preprocessed images. Each entry is in the format of
[[original_height, original_width], [input_height, input_width],
          [y_scale, x_scale], [y_offset, x_offset]], where [input_height,
          input_width] is the actual scaled image size, and [y_scale, x_scale]
          is the scaling factor, i.e. the ratio of scaled dimension to
          original dimension.
      y_pred: Tensor [batch, width_p, height_p, num_classes], predicted masks.
"""
predictions = y_pred
masks = y_true['masks']
valid_masks = y_true['valid_masks']
images_info = y_true['image_info']
    if isinstance(predictions, (tuple, list)):
predictions = tf.concat(predictions, axis=0)
masks = tf.concat(masks, axis=0)
valid_masks = tf.concat(valid_masks, axis=0)
images_info = tf.concat(images_info, axis=0)
    # Ignored mask elements are set to zero for the argmax op.
masks = tf.where(valid_masks, masks, tf.zeros_like(masks))
if self._rescale_predictions:
# This part can only run on cpu/gpu due to dynamic image resizing.
flatten_predictions = []
flatten_masks = []
flatten_valid_masks = []
for mask, valid_mask, predicted_mask, image_info in zip(
masks, valid_masks, predictions, images_info):
rescale_size = tf.cast(
tf.math.ceil(image_info[1, :] / image_info[2, :]), tf.int32)
image_shape = tf.cast(image_info[0, :], tf.int32)
offsets = tf.cast(image_info[3, :], tf.int32)
predicted_mask = tf.image.resize(
predicted_mask,
rescale_size,
method=tf.image.ResizeMethod.BILINEAR)
predicted_mask = tf.image.crop_to_bounding_box(predicted_mask,
offsets[0], offsets[1],
image_shape[0],
image_shape[1])
mask = tf.image.crop_to_bounding_box(mask, 0, 0, image_shape[0],
image_shape[1])
valid_mask = tf.image.crop_to_bounding_box(valid_mask, 0, 0,
image_shape[0],
image_shape[1])
predicted_mask = tf.argmax(predicted_mask, axis=2)
flatten_predictions.append(tf.reshape(predicted_mask, shape=[1, -1]))
flatten_masks.append(tf.reshape(mask, shape=[1, -1]))
flatten_valid_masks.append(tf.reshape(valid_mask, shape=[1, -1]))
flatten_predictions = tf.concat(flatten_predictions, axis=1)
flatten_masks = tf.concat(flatten_masks, axis=1)
flatten_valid_masks = tf.concat(flatten_valid_masks, axis=1)
else:
predictions = tf.image.resize(
predictions,
tf.shape(masks)[1:3],
method=tf.image.ResizeMethod.BILINEAR)
predictions = tf.argmax(predictions, axis=3)
flatten_predictions = tf.reshape(predictions, shape=[-1])
flatten_masks = tf.reshape(masks, shape=[-1])
flatten_valid_masks = tf.reshape(valid_masks, shape=[-1])
super(PerClassIoU, self).update_state(
flatten_masks, flatten_predictions,
tf.cast(flatten_valid_masks, tf.float32))
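A toy sketch exercising the fixed-size path (`rescale_predictions=False`); shapes are minimal, the random logits are illustrative, and `result()` is assumed to return one IoU per class as in keras_cv's PerClassIoU:

import tensorflow as tf

metric = PerClassIoU(num_classes=3)
y_pred = tf.random.uniform([2, 8, 8, 3])  # [batch, h, w, classes] logits
y_true = {
    'masks': tf.zeros([2, 16, 16, 1]),
    'valid_masks': tf.ones([2, 16, 16, 1], dtype=tf.bool),
    'image_info': tf.zeros([2, 4, 2]),  # unused when not rescaling
}
# Predictions are resized to the mask size, argmax'ed, and flattened.
metric.update_state(y_true, y_pred)
print(metric.result())  # length-3 vector of per-class IoUs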
@@ -69,10 +69,10 @@ RESNET_SPECS = {
('bottleneck', 256, 36),
('bottleneck', 512, 3),
],
300: [
270: [
('bottleneck', 64, 4),
('bottleneck', 128, 36),
('bottleneck', 256, 54),
('bottleneck', 128, 29),
('bottleneck', 256, 53),
('bottleneck', 512, 4),
],
350: [
@@ -81,6 +81,12 @@ RESNET_SPECS = {
('bottleneck', 256, 72),
('bottleneck', 512, 4),
],
420: [
('bottleneck', 64, 4),
('bottleneck', 128, 44),
('bottleneck', 256, 87),
('bottleneck', 512, 4),
],
}
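As a naming sanity check: bottleneck ResNet depth is conventionally 3 convs per block plus the stem conv and final dense layer, and these specs land on or near the advertised depths (the collapsed 350 entries are assumed to be (64, 4) and (128, 36), as in the ResNet-RS paper):

# 3 * sum(block repeats) + 2; the RS variants round to a nearby name.
for name, blocks in {270: [4, 29, 53, 4],
                     350: [4, 36, 72, 4],
                     420: [4, 44, 87, 4]}.items():
  print(name, 3 * sum(blocks) + 2)  # 270 -> 272, 350 -> 350, 420 -> 419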
@@ -93,6 +99,8 @@ class ResNet(tf.keras.Model):
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
depth_multiplier=1.0,
stem_type='v0',
resnetd_shortcut=False,
replace_stem_max_pool=False,
se_ratio=None,
init_stochastic_depth_rate=0.0,
activation='relu',
@@ -111,7 +119,11 @@
depth_multiplier: `float` a depth multiplier to uniformly scale up all
layers' channel sizes in ResNet.
stem_type: `str` stem type of ResNet. Default to `v0`. If set to `v1`,
use ResNet-C type stem (https://arxiv.org/abs/1812.01187).
use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
resnetd_shortcut: `bool` whether to use ResNet-D shortcut in downsampling
blocks.
replace_stem_max_pool: `bool` if True, replace the max pool in stem with
a stride-2 conv.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
init_stochastic_depth_rate: `float` initial stochastic depth rate.
activation: `str` name of the activation function.
@@ -130,6 +142,8 @@
self._input_specs = input_specs
self._depth_multiplier = depth_multiplier
self._stem_type = stem_type
self._resnetd_shortcut = resnetd_shortcut
self._replace_stem_max_pool = replace_stem_max_pool
self._se_ratio = se_ratio
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._use_sync_bn = use_sync_bn
@@ -213,7 +227,23 @@
else:
raise ValueError('Stem type {} not supported.'.format(stem_type))
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
if replace_stem_max_pool:
x = layers.Conv2D(
filters=int(64 * self._depth_multiplier),
kernel_size=3,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
x)
x = tf_utils.get_activation(activation)(x)
else:
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
endpoints = {}
for i, spec in enumerate(RESNET_SPECS[model_id]):
@@ -267,6 +297,7 @@
use_projection=True,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
se_ratio=self._se_ratio,
resnetd_shortcut=self._resnetd_shortcut,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
@@ -283,6 +314,7 @@
use_projection=False,
stochastic_depth_drop_rate=stochastic_depth_drop_rate,
se_ratio=self._se_ratio,
resnetd_shortcut=self._resnetd_shortcut,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
@@ -299,6 +331,8 @@
'model_id': self._model_id,
'depth_multiplier': self._depth_multiplier,
'stem_type': self._stem_type,
'resnetd_shortcut': self._resnetd_shortcut,
'replace_stem_max_pool': self._replace_stem_max_pool,
'activation': self._activation,
'se_ratio': self._se_ratio,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
@@ -338,6 +372,8 @@ def build_resnet(
input_specs=input_specs,
depth_multiplier=backbone_cfg.depth_multiplier,
stem_type=backbone_cfg.stem_type,
resnetd_shortcut=backbone_cfg.resnetd_shortcut,
replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
se_ratio=backbone_cfg.se_ratio,
init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
activation=norm_activation_config.activation,
......
@@ -84,20 +84,22 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
_ = network(inputs)
@parameterized.parameters(
(128, 34, 1, 'v0', None, 0.0, 1.0),
(128, 34, 1, 'v1', 0.25, 0.2, 1.25),
(128, 50, 4, 'v0', None, 0.0, 1.5),
(128, 50, 4, 'v1', 0.25, 0.2, 2.0),
(128, 34, 1, 'v0', None, 0.0, 1.0, False, False),
(128, 34, 1, 'v1', 0.25, 0.2, 1.25, True, True),
(128, 50, 4, 'v0', None, 0.0, 1.5, False, False),
(128, 50, 4, 'v1', 0.25, 0.2, 2.0, True, True),
)
def test_resnet_addons(self, input_size, model_id, endpoint_filter_scale,
stem_type, se_ratio, init_stochastic_depth_rate,
depth_multiplier):
def test_resnet_rs(self, input_size, model_id, endpoint_filter_scale,
stem_type, se_ratio, init_stochastic_depth_rate,
depth_multiplier, resnetd_shortcut, replace_stem_max_pool):
"""Test creation of ResNet family models."""
tf.keras.backend.set_image_data_format('channels_last')
network = resnet.ResNet(
model_id=model_id,
depth_multiplier=depth_multiplier,
stem_type=stem_type,
resnetd_shortcut=resnetd_shortcut,
replace_stem_max_pool=replace_stem_max_pool,
se_ratio=se_ratio,
init_stochastic_depth_rate=init_stochastic_depth_rate)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
@@ -121,6 +123,8 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
depth_multiplier=1.0,
stem_type='v0',
se_ratio=None,
resnetd_shortcut=False,
replace_stem_max_pool=False,
init_stochastic_depth_rate=0.0,
use_sync_bn=False,
activation='relu',
......
@@ -36,6 +36,7 @@ class ClassificationModel(tf.keras.Model):
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
skip_logits_layer: bool = False,
**kwargs):
"""Classification initialization function.
@@ -55,6 +56,7 @@
norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
skip_logits_layer: `bool`, whether to skip the prediction layer.
**kwargs: keyword arguments to be passed.
"""
self._self_setattr_tracking = False
@@ -88,12 +90,13 @@
if add_head_batch_norm:
x = self._norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(dropout_rate)(x)
x = tf.keras.layers.Dense(
num_classes, kernel_initializer=kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
if not skip_logits_layer:
x = tf.keras.layers.Dropout(dropout_rate)(x)
x = tf.keras.layers.Dense(
num_classes, kernel_initializer=kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
super(ClassificationModel, self).__init__(
inputs=inputs, outputs=x, **kwargs)
......
@@ -41,7 +41,8 @@ from official.vision.beta.modeling.layers import roi_sampler
def build_classification_model(
input_specs: tf.keras.layers.InputSpec,
model_config: classification_cfg.ImageClassificationModel,
l2_regularizer: tf.keras.regularizers.Regularizer = None):
l2_regularizer: tf.keras.regularizers.Regularizer = None,
skip_logits_layer: bool = False):
"""Builds the classification model."""
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
@@ -58,7 +59,8 @@
add_head_batch_norm=model_config.add_head_batch_norm,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
norm_epsilon=norm_activation_config.norm_epsilon,
skip_logits_layer=skip_logits_layer)
return model
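A short sketch of the new `skip_logits_layer` path: with it set, the model ends at global average pooling, which is useful for feature extraction or attaching a custom head (the config comes from the experiment registered earlier; treat the exact wiring as illustrative):

import tensorflow as tf
from official.core import exp_factory

config = exp_factory.get_exp_config('resnet_rs_imagenet')
model = build_classification_model(
    input_specs=tf.keras.layers.InputSpec(shape=[None, 160, 160, 3]),
    model_config=config.task.model,
    skip_logits_layer=True)
# Output is the pooled backbone feature; the dropout/dense head is skipped.
features = model(tf.zeros([1, 160, 160, 3]))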
......
@@ -182,6 +182,7 @@ class SegmentationHead(tf.keras.layers.Layer):
x = tf.image.resize(
x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.cast(x, dtype=y.dtype)
x = tf.concat([x, y], axis=self._bn_axis)
elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
x = nn_layers.pyramid_feature_fusion(decoder_output,
......
@@ -63,6 +63,7 @@ class ResidualBlock(tf.keras.layers.Layer):
strides,
use_projection=False,
se_ratio=None,
resnetd_shortcut=False,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
@@ -84,6 +85,8 @@
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
resnetd_shortcut: `bool` if True, apply the resnetd style modification to
the shortcut connection. Not implemented in residual blocks.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
@@ -104,6 +107,7 @@
self._strides = strides
self._use_projection = use_projection
self._se_ratio = se_ratio
self._resnetd_shortcut = resnetd_shortcut
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
@@ -191,6 +195,7 @@
'strides': self._strides,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'resnetd_shortcut': self._resnetd_shortcut,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
@@ -235,6 +240,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
dilation_rate=1,
use_projection=False,
se_ratio=None,
resnetd_shortcut=False,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
@@ -257,6 +263,8 @@
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
resnetd_shortcut: `bool` if True, apply the resnetd style modification to
the shortcut connection.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
@@ -278,6 +286,7 @@
self._dilation_rate = dilation_rate
self._use_projection = use_projection
self._se_ratio = se_ratio
self._resnetd_shortcut = resnetd_shortcut
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
@@ -298,14 +307,27 @@
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
if self._resnetd_shortcut:
self._shortcut0 = tf.keras.layers.AveragePooling2D(
pool_size=2, strides=self._strides, padding='same')
self._shortcut1 = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
@@ -378,6 +400,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
'dilation_rate': self._dilation_rate,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'resnetd_shortcut': self._resnetd_shortcut,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
@@ -393,7 +416,11 @@
def call(self, inputs, training=None):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
if self._resnetd_shortcut:
shortcut = self._shortcut0(shortcut)
shortcut = self._shortcut1(shortcut)
else:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
......
@@ -15,7 +15,6 @@
"""TFDS Classification decoder."""
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
@@ -27,10 +26,9 @@ class Decoder(decoder.Decoder):
def decode(self, serialized_example):
sample_dict = {
'image/encoded': tf.io.encode_jpeg(
serialized_example['image'], quality=100),
'image/class/label': serialized_example['label'],
'image/encoded':
tf.io.encode_jpeg(serialized_example['image'], quality=100),
'image/class/label':
serialized_example['label'],
}
return sample_dict
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection Data parser and processing for YOLO.
Parse image and ground truths in a dataset to training targets and package them
into (image, labels) tuple for RetinaNet.
"""
import tensorflow as tf
from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.projects.yolo.ops import box_ops as yolo_box_ops
from official.vision.beta.projects.yolo.ops import preprocess_ops as yolo_preprocess_ops
class Parser(parser.Parser):
"""Parser to parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
output_size,
num_classes,
fixed_size=True,
jitter_im=0.1,
jitter_boxes=0.005,
use_tie_breaker=True,
min_level=3,
max_level=5,
masks=None,
max_process_size=608,
min_process_size=320,
max_num_instances=200,
random_flip=True,
aug_rand_saturation=True,
aug_rand_brightness=True,
aug_rand_zoom=True,
aug_rand_hue=True,
anchors=None,
seed=10,
dtype=tf.float32):
"""Initializes parameters for parsing annotations in the dataset.
Args:
output_size: a `Tuple` for (width, height) of input image.
num_classes: a `Tensor` or `int` for the number of classes.
fixed_size: a `bool`; if True, all output images have the same size.
jitter_im: a `float` representing a pixel value that is the maximum jitter
applied to the image for data augmentation during training.
jitter_boxes: a `float` representing a pixel value that is the maximum
jitter applied to the bounding box for data augmentation during
training.
use_tie_breaker: a `bool` for whether or not to use the tie breaker.
min_level: `int` number of minimum level of the output feature pyramid.
max_level: `int` number of maximum level of the output feature pyramid.
masks: a `Tensor`, `List` or `numpy.ndarray` for anchor masks.
max_process_size: an `int` for maximum image width and height.
min_process_size: an `int` for minimum image width and height.
max_num_instances: an `int` number of maximum number of instances in an
image.
random_flip: a `bool` if True, augment training with random horizontal
flip.
aug_rand_saturation: `bool`, if True, augment training with random
saturation.
aug_rand_brightness: `bool`, if True, augment training with random
brightness.
aug_rand_zoom: `bool`, if True, augment training with random zoom.
aug_rand_hue: `bool`, if True, augment training with random hue.
anchors: a `Tensor`, `List` or `numpy.ndarray` for bounding box priors.
seed: an `int` for the seed used by tf.random.
dtype: a `tf.dtypes.DType` object that represents the dtype the outputs
will be casted to. The available types are tf.float32, tf.float16, or
tf.bfloat16.
"""
self._net_down_scale = 2**max_level
self._num_classes = num_classes
self._image_w = (output_size[0] //
self._net_down_scale) * self._net_down_scale
self._image_h = (output_size[1] //
self._net_down_scale) * self._net_down_scale
self._max_process_size = max_process_size
self._min_process_size = min_process_size
self._fixed_size = fixed_size
self._anchors = anchors
self._masks = {
key: tf.convert_to_tensor(value) for key, value in masks.items()
}
self._use_tie_breaker = use_tie_breaker
self._jitter_im = 0.0 if jitter_im is None else jitter_im
self._jitter_boxes = 0.0 if jitter_boxes is None else jitter_boxes
self._max_num_instances = max_num_instances
self._random_flip = random_flip
self._aug_rand_saturation = aug_rand_saturation
self._aug_rand_brightness = aug_rand_brightness
self._aug_rand_zoom = aug_rand_zoom
self._aug_rand_hue = aug_rand_hue
self._seed = seed
self._dtype = dtype
  def _build_grid(self, raw_true, width, batch=False, use_tie_breaker=False):
    # Build the grids into a fresh dict; aliasing self._masks here would let
    # the generated grids overwrite the anchor masks and corrupt later calls.
    grids = {}
    for key, mask in self._masks.items():
      build_fn = (
          yolo_preprocess_ops.build_batch_grided_gt
          if batch else yolo_preprocess_ops.build_grided_gt)
      grids[key] = build_fn(raw_true, mask, width // 2**int(key),
                            raw_true['bbox'].dtype, use_tie_breaker)
    return grids
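For intuition, each grid built above has one cell per 2**level input pixels; with the 416-px width used later in this file (toy arithmetic only):

for level in ('3', '4', '5'):
  print(level, 416 // 2 ** int(level))  # '3': 52, '4': 26, '5': 13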
def _parse_train_data(self, data):
"""Generates images and labels that are usable for model training.
Args:
data: a dict of Tensors produced by the decoder.
Returns:
images: the image tensor.
labels: a dict of Tensors that contains labels.
"""
shape = tf.shape(data['image'])
image = data['image'] / 255
boxes = data['groundtruth_boxes']
width = shape[0]
height = shape[1]
image, boxes = yolo_preprocess_ops.fit_preserve_aspect_ratio(
image,
boxes,
width=width,
height=height,
target_dim=self._max_process_size)
image_shape = tf.shape(image)[:2]
if self._random_flip:
image, boxes, _ = preprocess_ops.random_horizontal_flip(
image, boxes, seed=self._seed)
randscale = self._image_w // self._net_down_scale
if not self._fixed_size:
do_scale = tf.greater(
tf.random.uniform([], minval=0, maxval=1, seed=self._seed), 0.5)
if do_scale:
        # Scale the image to a random multiple of net_down_scale between
        # min_process_size and max_process_size.
randscale = tf.random.uniform(
[],
minval=self._min_process_size // self._net_down_scale,
maxval=self._max_process_size // self._net_down_scale,
seed=self._seed,
dtype=tf.int32) * self._net_down_scale
if self._jitter_boxes != 0.0:
boxes = box_ops.denormalize_boxes(boxes, image_shape)
boxes = box_ops.jitter_boxes(boxes, 0.025)
boxes = box_ops.normalize_boxes(boxes, image_shape)
# YOLO loss function uses x-center, y-center format
boxes = yolo_box_ops.yxyx_to_xcycwh(boxes)
if self._jitter_im != 0.0:
image, boxes = yolo_preprocess_ops.random_translate(
image, boxes, self._jitter_im, seed=self._seed)
if self._aug_rand_zoom:
image, boxes = yolo_preprocess_ops.resize_crop_filter(
image,
boxes,
default_width=self._image_w,
default_height=self._image_h,
target_width=randscale,
target_height=randscale)
    image = tf.image.resize(
        image, (self._image_h, self._image_w), preserve_aspect_ratio=False)
if self._aug_rand_brightness:
image = tf.image.random_brightness(
image=image, max_delta=.1) # Brightness
if self._aug_rand_saturation:
image = tf.image.random_saturation(
image=image, lower=0.75, upper=1.25) # Saturation
if self._aug_rand_hue:
image = tf.image.random_hue(image=image, max_delta=.3) # Hue
image = tf.clip_by_value(image, 0.0, 1.0)
# Find the best anchor for the ground truth labels to maximize the iou
best_anchors = yolo_preprocess_ops.get_best_anchor(
boxes, self._anchors, width=self._image_w, height=self._image_h)
# Padding
boxes = preprocess_ops.clip_or_pad_to_fixed_size(boxes,
self._max_num_instances, 0)
classes = preprocess_ops.clip_or_pad_to_fixed_size(
data['groundtruth_classes'], self._max_num_instances, -1)
best_anchors = preprocess_ops.clip_or_pad_to_fixed_size(
best_anchors, self._max_num_instances, 0)
area = preprocess_ops.clip_or_pad_to_fixed_size(data['groundtruth_area'],
self._max_num_instances, 0)
is_crowd = preprocess_ops.clip_or_pad_to_fixed_size(
tf.cast(data['groundtruth_is_crowd'], tf.int32),
self._max_num_instances, 0)
labels = {
'source_id': data['source_id'],
'bbox': tf.cast(boxes, self._dtype),
'classes': tf.cast(classes, self._dtype),
'area': tf.cast(area, self._dtype),
'is_crowd': is_crowd,
'best_anchors': tf.cast(best_anchors, self._dtype),
'width': width,
'height': height,
'num_detections': tf.shape(data['groundtruth_classes'])[0],
}
if self._fixed_size:
grid = self._build_grid(
labels, self._image_w, use_tie_breaker=self._use_tie_breaker)
labels.update({'grid_form': grid})
return image, labels
def _parse_eval_data(self, data):
"""Generates images and labels that are usable for model training.
Args:
data: a dict of Tensors produced by the decoder.
Returns:
images: the image tensor.
labels: a dict of Tensors that contains labels.
"""
shape = tf.shape(data['image'])
image = data['image'] / 255
boxes = data['groundtruth_boxes']
width = shape[0]
height = shape[1]
image, boxes = yolo_preprocess_ops.fit_preserve_aspect_ratio(
image, boxes, width=width, height=height, target_dim=self._image_w)
boxes = yolo_box_ops.yxyx_to_xcycwh(boxes)
# Find the best anchor for the ground truth labels to maximize the iou
best_anchors = yolo_preprocess_ops.get_best_anchor(
boxes, self._anchors, width=self._image_w, height=self._image_h)
boxes = yolo_preprocess_ops.pad_max_instances(boxes,
self._max_num_instances, 0)
classes = yolo_preprocess_ops.pad_max_instances(data['groundtruth_classes'],
self._max_num_instances, 0)
best_anchors = yolo_preprocess_ops.pad_max_instances(
best_anchors, self._max_num_instances, 0)
area = yolo_preprocess_ops.pad_max_instances(data['groundtruth_area'],
self._max_num_instances, 0)
is_crowd = yolo_preprocess_ops.pad_max_instances(
tf.cast(data['groundtruth_is_crowd'], tf.int32),
self._max_num_instances, 0)
labels = {
'source_id': data['source_id'],
'bbox': tf.cast(boxes, self._dtype),
'classes': tf.cast(classes, self._dtype),
'area': tf.cast(area, self._dtype),
'is_crowd': is_crowd,
'best_anchors': tf.cast(best_anchors, self._dtype),
'width': width,
'height': height,
'num_detections': tf.shape(data['groundtruth_classes'])[0],
}
grid = self._build_grid(
labels,
self._image_w,
batch=False,
use_tie_breaker=self._use_tie_breaker)
labels.update({'grid_form': grid})
return image, labels
def _postprocess_fn(self, image, label):
randscale = self._image_w // self._net_down_scale
if not self._fixed_size:
do_scale = tf.greater(
tf.random.uniform([], minval=0, maxval=1, seed=self._seed), 0.5)
if do_scale:
        # Scale the image to a random multiple of net_down_scale between
        # min_process_size and max_process_size.
randscale = tf.random.uniform(
[],
minval=self._min_process_size // self._net_down_scale,
maxval=self._max_process_size // self._net_down_scale,
seed=self._seed,
dtype=tf.int32) * self._net_down_scale
width = randscale
image = tf.image.resize(image, (width, width))
grid = self._build_grid(
label, width, batch=True, use_tie_breaker=self._use_tie_breaker)
label.update({'grid_form': grid})
return image, label
def postprocess_fn(self, is_training=True):
return self._postprocess_fn if not self._fixed_size and is_training else None
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test case for YOLO detection dataloader configuration definition."""
from absl.testing import parameterized
import dataclasses
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
from official.modeling import hyperparams
from official.vision.beta.dataloaders import tfds_detection_decoders
from official.vision.beta.projects.yolo.dataloaders import yolo_detection_input
@dataclasses.dataclass
class Parser(hyperparams.Config):
"""Dummy configuration for parser."""
output_size: tuple = (416, 416)
num_classes: int = 80
fixed_size: bool = True
jitter_im: float = 0.1
jitter_boxes: float = 0.005
min_process_size: int = 320
max_process_size: int = 608
max_num_instances: int = 200
random_flip: bool = True
seed: int = 10
shuffle_buffer_size: int = 10000
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
input_path: str = ''
tfds_name: str = 'coco/2017'
tfds_split: str = 'train'
global_batch_size: int = 10
is_training: bool = True
dtype: str = 'float16'
decoder = None
parser: Parser = Parser()
shuffle_buffer_size: int = 10
tfds_download: bool = False
class YoloDetectionInputTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(('training', True), ('testing', False))
def test_yolo_input(self, is_training):
params = DataConfig(is_training=is_training)
decoder = tfds_detection_decoders.MSCOCODecoder()
anchors = [[12.0, 19.0], [31.0, 46.0], [96.0, 54.0], [46.0, 114.0],
[133.0, 127.0], [79.0, 225.0], [301.0, 150.0], [172.0, 286.0],
[348.0, 340.0]]
masks = {'3': [0, 1, 2], '4': [3, 4, 5], '5': [6, 7, 8]}
parser = yolo_detection_input.Parser(
output_size=params.parser.output_size,
num_classes=params.parser.num_classes,
fixed_size=params.parser.fixed_size,
jitter_im=params.parser.jitter_im,
jitter_boxes=params.parser.jitter_boxes,
min_process_size=params.parser.min_process_size,
max_process_size=params.parser.max_process_size,
max_num_instances=params.parser.max_num_instances,
random_flip=params.parser.random_flip,
seed=params.parser.seed,
anchors=anchors,
masks=masks)
postprocess_fn = parser.postprocess_fn(is_training=is_training)
reader = input_reader.InputReader(params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(
params.is_training))
dataset = reader.read(input_context=None).batch(10).take(1)
if postprocess_fn:
image, _ = postprocess_fn(
*tf.data.experimental.get_single_element(dataset))
else:
image, _ = tf.data.experimental.get_single_element(dataset)
print(image.shape)
self.assertAllEqual(image.shape, (10, 10, 416, 416, 3))
self.assertTrue(
tf.reduce_all(tf.math.logical_and(image >= 0, image <= 1)))
if __name__ == '__main__':
tf.test.main()