Commit add6e22f authored by Vishnu Banna

datapipeline update

parent d09d4bef
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Example proto decoder for object detection.
A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import tensorflow as tf

from official.vision.beta.dataloaders import tf_example_decoder


def _coco91_to_80(classif, box, areas, iscrowds):
"""Function used to reduce COCO 91 to COCO 80, or to convert from the 2017
foramt to the 2014 format"""
  # Vector where the value at index i is the COCO 91 class id that maps to
  # COCO 80 class i.
x = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,
44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62,
63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85,
86, 87, 88, 89, 90
]
no = tf.expand_dims(tf.convert_to_tensor(x), axis=0)
  # Reshape the classes in order to build a class mask.
ce = tf.expand_dims(classif, axis=-1)
  # One-hot the classifications to match the 80 class format.
ind = ce == tf.cast(no, ce.dtype)
# Select the max values.
co = tf.reshape(tf.math.argmax(tf.cast(ind, tf.float32), axis=-1), [-1])
ind = tf.where(tf.reduce_any(ind, axis=-1))
  # Gather the valid instances.
classif = tf.gather_nd(co, ind)
box = tf.gather_nd(box, ind)
areas = tf.gather_nd(areas, ind)
iscrowds = tf.gather_nd(iscrowds, ind)
  # Recompute the number of viable detections; ideally it is unchanged.
num_detections = tf.shape(classif)[0]
return classif, box, areas, iscrowds, num_detections
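

# A minimal sketch (not part of the commit) of how the remap above behaves
# under eager execution: ids absent from the 91-class vector are dropped
# along with their boxes, and surviving ids are re-indexed into [0, 80).
#
#   classes = tf.constant([1, 12, 90], tf.int64)
#   boxes = tf.zeros([3, 4], tf.float32)
#   areas = tf.ones([3], tf.float32)
#   iscrowds = tf.zeros([3], tf.int64)
#   classes, boxes, areas, iscrowds, n = _coco91_to_80(
#       classes, boxes, areas, iscrowds)
#   # classes -> [0, 79]; n -> 2; the id-12 annotation is removed.
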
class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self,
coco91_to_80,
include_mask=False,
regenerate_source_id=False,
mask_binarize_threshold=None):
    if coco91_to_80 and include_mask:
      raise ValueError('If masks are included, you cannot convert COCO from '
                       'the 91 class format to the 80 class format.')
self._coco91_to_80 = coco91_to_80
super().__init__(
include_mask=include_mask,
regenerate_source_id=regenerate_source_id,
mask_binarize_threshold=mask_binarize_threshold
)
def decode(self, serialized_example):
"""Decode the serialized example.
Args:
serialized_example: a single serialized tf.Example string.
Returns:
decoded_tensors: a dictionary of tensors with the following fields:
- source_id: a string scalar tensor.
- image: a uint8 tensor of shape [None, None, 3].
- height: an integer scalar tensor.
- width: an integer scalar tensor.
        - groundtruth_classes: an int64 tensor of shape [None].
- groundtruth_is_crowd: a bool tensor of shape [None].
- groundtruth_area: a float32 tensor of shape [None].
- groundtruth_boxes: a float32 tensor of shape [None, 4].
- groundtruth_instance_masks: a float32 tensor of shape
[None, None, None].
- groundtruth_instance_masks_png: a string tensor of shape [None].
"""
decoded_tensors = super().decode(serialized_example)
if self._coco91_to_80:
(decoded_tensors['groundtruth_classes'],
decoded_tensors['groundtruth_boxes'],
decoded_tensors['groundtruth_area'],
decoded_tensors['groundtruth_is_crowd'],
_) = _coco91_to_80(decoded_tensors['groundtruth_classes'],
decoded_tensors['groundtruth_boxes'],
decoded_tensors['groundtruth_area'],
decoded_tensors['groundtruth_is_crowd'])
return decoded_tensors
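

# A minimal usage sketch (not part of the commit; the TFRecord path is a
# placeholder):
#
#   decoder = TfExampleDecoder(coco91_to_80=True)
#   dataset = tf.data.TFRecordDataset('/path/to/coco/records')
#   dataset = dataset.map(decoder.decode)
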
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. """ Detection Data parser and processing for YOLO.
# Parse image and ground truths in a dataset to training targets and package them
# Licensed under the Apache License, Version 2.0 (the "License"); into (image, labels) tuple for RetinaNet.
# you may not use this file except in compliance with the License. """
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Detection Data parser and processing for YOLO."""
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from official.vision.beta.projects.yolo.ops import preprocessing_ops from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.projects.yolo.ops import box_ops as box_utils from official.vision.beta.projects.yolo.ops import anchor
from official.vision.beta.ops import preprocess_ops from official.vision.beta.ops import preprocess_ops
from official.vision.beta.ops import box_ops as bbox_ops
from official.vision.beta.dataloaders import parser, utils from official.vision.beta.dataloaders import parser, utils


class Parser(parser.Parser):
  """Parse the dataset into the YOLO model format."""

  def __init__(
      self,
      output_size,
      anchors,
      expanded_strides,
      anchor_free_limits=None,
      max_num_instances=200,
      area_thresh=0.1,
@@ -82,23 +37,18 @@ class Parser(parser.Parser):
      anchor_t=4.0,
      scale_xy=None,
      best_match_only=False,
      darknet=False,
      use_tie_breaker=True,
      dtype='float32',
      seed=None):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: `Tensor` or `List` for [height, width] of output image. The
        output_size should be divisible by the largest feature stride
        2^max_level.
      anchors: `Dict[List[Union[int, float]]]` values for each anchor box.
      expanded_strides: `Dict[int]` for how much the model scales down the
        images at the largest level.
      anchor_free_limits: `List` the box sizes that will be allowed at each FPN
        level as is done in the FCOS and YOLOX paper for anchor free box
        assignment. Anchor free will perform worse than Anchor based, but only
@@ -144,9 +94,7 @@ class Parser(parser.Parser):
        there should be one value for scale_xy for each level from min_level to
        max_level.
      best_match_only: `boolean` indicating how boxes are selected for
        optimization.
      darknet: `boolean` indicating which data pipeline to use. Setting to True
        swaps the pipeline to output images relative to Yolov4 and older.
      use_tie_breaker: `boolean` indicating whether to use the anchor threshold
@@ -155,25 +103,23 @@ class Parser(parser.Parser):
        from {"float32", "float16", "bfloat16"}.
      seed: `int` the seed for random number generation.
    """
    for key in anchors.keys():
      # Assert that the width and height are viable.
      assert output_size[1] % expanded_strides[str(key)] == 0
      assert output_size[0] % expanded_strides[str(key)] == 0

    # Scale of each FPN level.
    self._strides = expanded_strides

    # Set the width and height properly and base init:
    self._image_w = output_size[1]
    self._image_h = output_size[0]

    # Set the anchor boxes for each scale.
    self._anchors = anchors
    self._anchor_free_limits = anchor_free_limits

    # Anchor labeling parameters.
    self._use_tie_breaker = use_tie_breaker
    self._best_match_only = best_match_only
    self._max_num_instances = max_num_instances
@@ -202,7 +148,7 @@ class Parser(parser.Parser):
    self._darknet = darknet
    self._area_thresh = area_thresh

    keys = list(self._anchors.keys())

    if self._anchor_free_limits is not None:
      maxim = 2000
@@ -218,10 +164,15 @@ class Parser(parser.Parser):
    # Set the data type based on input string.
    self._dtype = dtype

    self._label_builder = anchor.YoloAnchorLabeler(
        anchors=self._anchors,
        match_threshold=self._anchor_t,
        best_matches_only=self._best_match_only,
        use_tie_breaker=self._use_tie_breaker)

  def _pad_infos_object(self, image):
    """Get a Tensor to pad the info object list."""
    shape_ = tf.shape(image)
    val = tf.stack([
        tf.cast(shape_[:2], tf.float32),
@@ -234,16 +185,16 @@ class Parser(parser.Parser):

  def _jitter_scale(self, image, shape, letter_box, jitter, random_pad,
                    aug_scale_min, aug_scale_max, translate, angle,
                    perspective):
    """Distort and scale each input image."""
    infos = []
    if (aug_scale_min != 1.0 or aug_scale_max != 1.0):
      crop_only = True
      # Jitter gives you only one info object and resize and crop gives you
      # one; if crop only, there can be one from jitter and one from crop.
      infos.append(self._pad_infos_object(image))
    else:
      crop_only = False
    image, crop_info, _ = preprocessing_ops.resize_and_jitter_image(
        image,
        shape,
        letter_box=letter_box,
@@ -252,10 +203,7 @@ class Parser(parser.Parser):
        random_pad=random_pad,
        seed=self._seed,
    )
    infos.extend(crop_info)
    image, _, affine = preprocessing_ops.affine_warp_image(
        image,
        shape,
@@ -269,21 +217,8 @@ class Parser(parser.Parser):
    )
    return image, infos, affine
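
  # Note on `_jitter_scale`: `infos` is padded via `_pad_infos_object` so both
  # branches yield the same number of info records; per the helper's
  # docstring, graph compilation requires the info list length to be static.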

  def _parse_train_data(self, data):
    """Parses data for training."""

    # Initialize the shape constants.
    image = data['image']
@@ -316,12 +251,16 @@ class Parser(parser.Parser):
    else:
      image = tf.image.resize(
          image, (self._image_h, self._image_w), method='nearest')
    output_size = tf.cast([640, 640], tf.float32)
    boxes_ = bbox_ops.denormalize_boxes(boxes, output_size)
    inds = bbox_ops.get_non_empty_box_indices(boxes_)
    boxes = tf.gather(boxes, inds)
    classes = tf.gather(classes, inds)
    info = self._pad_infos_object(image)
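    # Note: `boxes` are normalized, so the fixed 640x640 scale above matters
    # only for the non-empty check; any positive scale would select the same
    # indices.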

    # Apply scaling to the hue, saturation, and brightness of an image.
    image = tf.cast(image, dtype=self._dtype)
    image = image / 255.0
    image = preprocessing_ops.image_rand_hsv(
        image,
        self._aug_rand_hue,
@@ -331,30 +270,20 @@ class Parser(parser.Parser):
        darknet=self._darknet)

    # Cast the image to the selected datatype.
    image, labels = self._build_label(image, boxes, classes,
                                      info, inds, data, is_training=True)
    return image, labels

  def _parse_eval_data(self, data):
    """Parses data for evaluation."""

    # Get the image shape constants and cast the image to the selected
    # datatype.
    image = tf.cast(data['image'], dtype=self._dtype)
    boxes = data['groundtruth_boxes']
    classes = data['groundtruth_classes']

    image, infos, _ = preprocessing_ops.resize_and_jitter_image(
        image, [self._image_h, self._image_w],
        letter_box=self._letter_box,
        random_pad=False,
        shiftx=0.5,
@@ -362,7 +291,7 @@ class Parser(parser.Parser):
        jitter=0.0)

    # Clip and clean boxes.
    image = image / 255.0
    boxes, inds = preprocessing_ops.apply_infos(
        boxes, infos, shuffle_boxes=False, area_thresh=0.0, augment=True)
    classes = tf.gather(classes, inds)
@@ -372,8 +301,6 @@ class Parser(parser.Parser):
        image,
        boxes,
        classes,
        info,
        inds,
        data,
@@ -381,6 +308,7 @@ class Parser(parser.Parser):
    return image, labels

  def set_shape(self, values, pad_axis=0, pad_value=0, inds=None, scale=1):
    """Calls set shape for all input objects."""
    if inds is not None:
      values = tf.gather(values, inds)
    vshape = values.get_shape().as_list()
@@ -396,8 +324,8 @@ class Parser(parser.Parser):
    values.set_shape(vshape)
    return values

  def _build_grid(self, boxes, classes, width, height):
    """Private function for building the full scale object and class grid."""
    indexes = {}
    updates = {}
    true_grids = {}
@@ -406,27 +334,19 @@ class Parser(parser.Parser):
      self._anchor_free_limits = [0.0] + self._anchor_free_limits + [np.inf]

    # For each prediction path, generate a properly scaled output prediction
    # map.
    for i, key in enumerate(self._anchors.keys()):
      if self._anchor_free_limits is not None:
        fpn_limits = self._anchor_free_limits[i:i + 2]
      else:
        fpn_limits = None

      scale_xy = self._scale_xy[key] if not self._darknet else 1
      indexes[key], updates[key], true_grids[key] = self._label_builder(
          key, boxes, classes, self._anchors[key],
          width, height, self._strides[str(key)],
          scale_xy, self._max_num_instances * self._scale_up[key],
          fpn_limits=fpn_limits)

      # Set/fix the shapes.
      indexes[key] = self.set_shape(indexes[key], -2, None, None,
@@ -442,54 +362,39 @@ class Parser(parser.Parser):
                   image,
                   gt_boxes,
                   gt_classes,
                   info,
                   inds,
                   data,
                   is_training=True):
    """Label construction for both the train and eval data."""
    width = self._image_w
    height = self._image_h

    # Set the image shape.
    imshape = image.get_shape().as_list()
    imshape[-1] = 3
    image.set_shape(imshape)

    labels = dict()
    labels['inds'], labels['upds'], labels['true_conf'] = self._build_grid(
        gt_boxes, gt_classes, width, height)

    # Set/fix the boxes shape.
    boxes = self.set_shape(gt_boxes, pad_axis=0, pad_value=0)
    classes = self.set_shape(gt_classes, pad_axis=0, pad_value=-1)
    area = self.set_shape(
        data['groundtruth_area'], pad_axis=0, pad_value=0, inds=inds)
    is_crowd = self.set_shape(
        data['groundtruth_is_crowd'], pad_axis=0, pad_value=0, inds=inds)

    # Build the dictionary set.
    labels.update({
        'source_id': utils.process_source_id(data['source_id']),
        'bbox': tf.cast(boxes, dtype=self._dtype),
        'classes': tf.cast(classes, dtype=self._dtype),
    })

    # Update the labels dictionary.
    if not is_training:
      # Sets up groundtruth data for evaluation.
      groundtruths = {
@@ -509,3 +414,5 @@ class Parser(parser.Parser):
          groundtruths, self._max_num_instances)
      labels['groundtruths'] = groundtruths
    return image, labels
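

# A hedged usage sketch (not from the commit; the anchor dimensions, strides,
# and scales below are illustrative values):
#
#   parser = Parser(
#       output_size=[640, 640],
#       anchors={'3': [[10, 13], [16, 30], [33, 23]],
#                '4': [[30, 61], [62, 45], [59, 119]],
#                '5': [[116, 90], [156, 198], [373, 326]]},
#       expanded_strides={'3': 8, '4': 16, '5': 32},
#       scale_xy={'3': 2.0, '4': 2.0, '5': 2.0})
#   image, labels = parser._parse_train_data(decoded_tensors)
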
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.yolo.ops import box_ops
from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.projects.yolo.ops import loss_utils
def get_best_anchor(y_true,
anchors,
stride,
width=1,
height=1,
iou_thresh=0.25,
best_match_only=False,
use_tie_breaker=True):
"""
get the correct anchor that is assoiciated with each box using IOU
Args:
y_true: tf.Tensor[] for the list of bounding boxes in the yolo format
anchors: list or tensor for the anchor boxes to be used in prediction
found via Kmeans
width: int for the image width
height: int for the image height
Return:
tf.Tensor: y_true with the anchor associated with each ground truth
box known
"""
with tf.name_scope('get_best_anchor'):
width = tf.cast(width, dtype=tf.float32)
height = tf.cast(height, dtype=tf.float32)
scaler = tf.convert_to_tensor([width, height])
true_wh = tf.cast(y_true[..., 2:4], dtype=tf.float32) * scaler
anchors = tf.cast(anchors, dtype=tf.float32)/stride
k = tf.shape(anchors)[0]
anchors = tf.concat([tf.zeros_like(anchors), anchors], axis=-1)
truth_comp = tf.concat([tf.zeros_like(true_wh), true_wh], axis=-1)
if iou_thresh >= 1.0:
anchors = tf.expand_dims(anchors, axis=-2)
truth_comp = tf.expand_dims(truth_comp, axis=-3)
aspect = truth_comp[..., 2:4] / anchors[..., 2:4]
aspect = tf.where(tf.math.is_nan(aspect), tf.zeros_like(aspect), aspect)
aspect = tf.maximum(aspect, 1 / aspect)
aspect = tf.where(tf.math.is_nan(aspect), tf.zeros_like(aspect), aspect)
aspect = tf.reduce_max(aspect, axis=-1)
values, indexes = tf.math.top_k(
tf.transpose(-aspect, perm=[1, 0]),
k=tf.cast(k, dtype=tf.int32),
sorted=True)
values = -values
ind_mask = tf.cast(values < iou_thresh, dtype=indexes.dtype)
else:
truth_comp = box_ops.xcycwh_to_yxyx(truth_comp)
anchors = box_ops.xcycwh_to_yxyx(anchors)
iou_raw = box_ops.aggregated_comparitive_iou(
truth_comp,
anchors,
iou_type=3,
)
      values, indexes = tf.math.top_k(
          iou_raw,
          k=tf.cast(k, dtype=tf.int32),
          sorted=True)
ind_mask = tf.cast(values >= iou_thresh, dtype=indexes.dtype)
    # Pad the indexes such that all values less than the thresh are -1:
    # add one, multiply by the mask to zero all the bad locations,
    # then subtract 1, making all the bad locations -1.
if best_match_only:
iou_index = ((indexes[..., 0:] + 1) * ind_mask[..., 0:]) - 1
elif use_tie_breaker:
iou_index = tf.concat([
tf.expand_dims(indexes[..., 0], axis=-1),
((indexes[..., 1:] + 1) * ind_mask[..., 1:]) - 1], axis=-1)
else:
iou_index = tf.concat([
tf.expand_dims(indexes[..., 0], axis=-1),
tf.zeros_like(indexes[..., 1:]) - 1], axis=-1)
return tf.cast(iou_index, dtype=tf.float32), tf.cast(values, dtype=tf.float32)
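

# A sketch of the matching above (toy values, not from the commit): a single
# ground truth box with normalized width/height (0.1, 0.1) on a 640x640 image
# at stride 32 is scaled into grid units and compared against every anchor
# (also scaled by the stride); `top_k` then yields anchor indexes sorted by
# match quality, and any index whose value falls below `iou_thresh` is
# replaced with -1 by the masking arithmetic.
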
class YoloAnchorLabeler:
  """Labeler that maps ground truth boxes to anchors and grid locations."""

  def __init__(self,
               anchors=None,
               match_threshold=0.25,
               best_matches_only=False,
               use_tie_breaker=True):
    self.anchors = anchors
    self.masks = self._get_mask()
    self.match_threshold = match_threshold
    self.best_matches_only = best_matches_only
    self.use_tie_breaker = use_tie_breaker
  def _get_mask(self):
    """Computes the flat anchor index range (mask) owned by each level."""
masks = {}
start = 0
minimum = int(min(self.anchors.keys()))
maximum = int(max(self.anchors.keys()))
for i in range(minimum, maximum + 1):
per_scale = len(self.anchors[str(i)])
masks[str(i)] = list(range(start, per_scale + start))
start += per_scale
return masks
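
  # For example (assuming three anchors per level), anchors keyed
  # {'3': [...], '4': [...], '5': [...]} yield masks
  # {'3': [0, 1, 2], '4': [3, 4, 5], '5': [6, 7, 8]}, i.e. each level's index
  # range into the flat list of all anchors stitched across levels.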
  def _tie_breaking_search(self, anchors, mask, boxes, classes):
    """Gathers the boxes and classes whose anchor ids appear in `mask`."""
    mask = tf.cast(tf.reshape(mask, [1, 1, 1, -1]), anchors.dtype)
    anchors = tf.expand_dims(anchors, axis=-1)
    viable = tf.where(tf.squeeze(anchors == mask, axis=0))

    gather_id, _, anchor_id = tf.split(viable, 3, axis=-1)
    boxes = tf.gather_nd(boxes, gather_id)
    classes = tf.gather_nd(classes, gather_id)
    classes = tf.expand_dims(classes, axis=-1)
    classes = tf.cast(classes, boxes.dtype)
    anchor_id = tf.cast(anchor_id, boxes.dtype)
    return boxes, classes, anchor_id
def _get_anchor_id(self, key, boxes, classes, anchors, width, height, stride):
"""Find the object anchor assignments in an anchor based paradigm. """
# find the best anchor
num_anchors = len(anchors)
if self.best_matches_only:
# get the best anchor for each box
iou_index, _ = get_best_anchor(boxes, anchors, stride,
width=width, height=height,
best_match_only=True,
iou_thresh=self.match_threshold)
mask = range(num_anchors)
else:
# stitch and search boxes across fpn levels
anchorsvec = []
for stitch in self.anchors.keys():
anchorsvec.extend(self.anchors[stitch])
# get the best anchor for each box
iou_index, _ = get_best_anchor(boxes, anchorsvec, stride,
width=width, height=height,
best_match_only=False,
use_tie_breaker=self.use_tie_breaker,
iou_thresh=self.match_threshold)
mask = self.masks[key]
# search for the correct box to use
(boxes,
classes,
anchors) = self._tie_breaking_search(iou_index, mask, boxes, classes)
return boxes, classes, anchors, num_anchors
  def _get_centers(self, boxes, classes, anchors, width, height, offset):
    """Find the object center assignments in an anchor based paradigm."""
    grid_xy, wh = tf.split(boxes, 2, axis=-1)
    wh_scale = tf.cast(tf.convert_to_tensor([width, height]), boxes.dtype)

    grid_xy = grid_xy * wh_scale
    centers = tf.math.floor(grid_xy)

    if offset != 0.0:
      clamp = lambda x, ma: tf.maximum(
          tf.minimum(x, tf.cast(ma, x.dtype)), tf.zeros_like(x))

      grid_xy_index = grid_xy - centers
      positive_shift = ((grid_xy_index < offset) & (grid_xy > 1.))
      negative_shift = (
          (grid_xy_index > (1 - offset)) & (grid_xy < (wh_scale - 1.)))

      zero, _ = tf.split(tf.ones_like(positive_shift), 2, axis=-1)
      shift_mask = tf.concat([zero, positive_shift, negative_shift], axis=-1)
      offset = tf.cast([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]],
                       offset.dtype) * offset

      num_shifts = tf.shape(shift_mask)
      num_shifts = num_shifts[-1]
      boxes = tf.tile(tf.expand_dims(boxes, axis=-2), [1, num_shifts, 1])
      classes = tf.tile(tf.expand_dims(classes, axis=-2), [1, num_shifts, 1])
      anchors = tf.tile(tf.expand_dims(anchors, axis=-2), [1, num_shifts, 1])

      shift_mask = tf.cast(shift_mask, boxes.dtype)
      shift_ind = shift_mask * tf.range(0, num_shifts, dtype=boxes.dtype)
      shift_ind = shift_ind - (1 - shift_mask)
      shift_ind = tf.expand_dims(shift_ind, axis=-1)

      boxes_and_centers = tf.concat([boxes, classes, anchors, shift_ind],
                                    axis=-1)
      boxes_and_centers = tf.reshape(boxes_and_centers, [-1, 7])
      _, center_ids = tf.split(boxes_and_centers, [6, 1], axis=-1)

      select = tf.where(center_ids >= 0)
      select, _ = tf.split(select, 2, axis=-1)

      boxes_and_centers = tf.gather_nd(boxes_and_centers, select)
      center_ids = tf.gather_nd(center_ids, select)
      center_ids = tf.cast(center_ids, tf.int32)
      shifts = tf.gather_nd(offset, center_ids)

      boxes, classes, anchors, _ = tf.split(boxes_and_centers, [4, 1, 1, 1],
                                            axis=-1)
      grid_xy, _ = tf.split(boxes, 2, axis=-1)
      centers = tf.math.floor(grid_xy * wh_scale - shifts)
      centers = clamp(centers, wh_scale - 1)

    x, y = tf.split(centers, 2, axis=-1)
    centers = tf.cast(tf.concat([y, x, anchors], axis=-1), tf.int32)
    return boxes, classes, centers
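
  # A worked example for `_get_centers` (hypothetical numbers, not from the
  # commit): with scale_xy = 2.0 the offset is 0.5, so a box center at grid
  # coordinate (4.3, 7.8) is also assigned to cell x=3 (fraction 0.3 < 0.5)
  # and cell y=8 (fraction 0.8 > 1 - 0.5), mirroring the YOLOv4/v5-style
  # multi-cell assignment.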
def _get_anchor_free(self,
boxes,
classes,
height,
width,
stride,
fpn_limits,
center_radius=2.5):
"""Find the box assignements in an anchor free paradigm. """
gen = loss_utils.GridGenerator(
masks=None, anchors=[[1, 1]], scale_anchors=stride)
grid_points = gen(width, height, 1, boxes.dtype)[0]
grid_points = tf.squeeze(grid_points, axis=0)
box_list = boxes
class_list = classes
grid_points = (grid_points + 0.5) * stride
x_centers, y_centers = grid_points[..., 0], grid_points[..., 1]
boxes *= (tf.convert_to_tensor([width, height, width, height]) * stride)
tlbr_boxes = box_ops.xcycwh_to_yxyx(boxes)
boxes = tf.reshape(boxes, [1, 1, -1, 4])
tlbr_boxes = tf.reshape(tlbr_boxes, [1, 1, -1, 4])
if self.use_tie_breaker:
area = tf.reduce_prod(boxes[..., 2:], axis = -1)
    # Check if the box is in the receptive field of this FPN level.
b_t = y_centers - tlbr_boxes[..., 0]
b_l = x_centers - tlbr_boxes[..., 1]
b_b = tlbr_boxes[..., 2] - y_centers
b_r = tlbr_boxes[..., 3] - x_centers
box_delta = tf.stack([b_t, b_l, b_b, b_r], axis=-1)
if fpn_limits is not None:
max_reg_targets_per_im = tf.reduce_max(box_delta, axis=-1)
gt_min = max_reg_targets_per_im >= fpn_limits[0]
gt_max = max_reg_targets_per_im <= fpn_limits[1]
is_in_boxes = tf.logical_and(gt_min, gt_max)
else:
is_in_boxes = tf.reduce_min(box_delta, axis=-1) > 0.0
is_in_boxes_all = tf.reduce_any(is_in_boxes, axis=(0, 1), keepdims=True)
    # Check if the center is in the receptive field of this FPN level.
c_t = y_centers - (boxes[..., 1] - center_radius * stride)
c_l = x_centers - (boxes[..., 0] - center_radius * stride)
c_b = (boxes[..., 1] + center_radius * stride) - y_centers
c_r = (boxes[..., 0] + center_radius * stride) - x_centers
centers_delta = tf.stack([c_t, c_l, c_b, c_r], axis=-1)
is_in_centers = tf.reduce_min(centers_delta, axis=-1) > 0.0
is_in_centers_all = tf.reduce_any(is_in_centers, axis=(0, 1), keepdims=True)
    # Collate all masks to get the final locations.
is_in_index = tf.logical_or(is_in_boxes_all, is_in_centers_all)
is_in_boxes_and_center = tf.logical_and(is_in_boxes, is_in_centers)
is_in_boxes_and_center = tf.logical_and(is_in_index, is_in_boxes_and_center)
if self.use_tie_breaker:
inf = 10000000
boxes_all = tf.cast(is_in_boxes_and_center, area.dtype)
boxes_all = ((boxes_all * area) + ((1 - boxes_all) * inf))
boxes_min = tf.reduce_min(boxes_all, axis = -1, keepdims = True)
boxes_min = tf.where(boxes_min == inf, -1.0, boxes_min)
is_in_boxes_and_center = boxes_all == boxes_min
# construct the index update grid
reps = tf.reduce_sum(tf.cast(is_in_boxes_and_center, tf.int16), axis=-1)
indexes = tf.cast(tf.where(is_in_boxes_and_center), tf.int32)
y, x, t = tf.split(indexes, 3, axis=-1)
boxes = tf.gather_nd(box_list, t)
classes = tf.cast(tf.gather_nd(class_list, t), boxes.dtype)
reps = tf.gather_nd(reps, tf.concat([y, x], axis=-1))
reps = tf.cast(tf.expand_dims(reps, axis=-1), boxes.dtype)
classes = tf.cast(tf.expand_dims(classes, axis=-1), boxes.dtype)
conf = tf.ones_like(classes)
# return the samples and the indexes
samples = tf.concat([boxes, conf, classes], axis=-1)
indexes = tf.concat([y, x, tf.zeros_like(t)], axis=-1)
return indexes, samples
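
  # In short: a grid location is a positive sample if it falls inside a
  # ground-truth box whose regression targets fit this level's `fpn_limits`,
  # or within `center_radius * stride` of a box center; with
  # `use_tie_breaker`, contested locations go to the smallest box by area.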
  def __call__(self,
               key,
               boxes,
               classes,
               anchors,
               width,
               height,
               stride,
               scale_xy,
               num_instances,
               fpn_limits=None):
"""Builds the labels for a single image, not functional in batch mode.
Args:
boxes: `Tensor` of shape [None, 4] indicating the object locations in
an image.
classes: `Tensor` of shape [None] indicating the each objects classes.
anchors: `List[List[int, float]]` representing the anchor boxes to build
the model against.
width: `int` for the images width.
height: `int` for the images height.
stride: `int` for how much the image gets scaled at this level.
scale_xy: `float` for the center shifts to apply when finding center
assignments for a box.
num_instances: `int` for the maximum number of expanded boxes to allow.
fpn_limits: `List[int]` given no anchor boxes this is used to limit the
boxes assied to the each fpn level based on the levels receptive feild.
Returns:
centers: `Tensor` of shape [None, 3] of indexes in the final grid where
boxes are located.
updates: `Tensor` of shape [None, 8] the value to place in the final grid.
full: `Tensor` of [width/stride, height/stride, num_anchors, 1] holding
a mask of where boxes are locates for confidence losses.
"""
    boxes = box_ops.yxyx_to_xcycwh(boxes)
    width //= stride
    height //= stride

    width = tf.cast(width, boxes.dtype)
    height = tf.cast(height, boxes.dtype)
    if fpn_limits is None:
      offset = tf.cast(0.5 * (scale_xy - 1), boxes.dtype)
      (boxes, classes,
       anchors, num_anchors) = self._get_anchor_id(key, boxes, classes,
                                                   anchors, width, height,
                                                   stride)
      boxes, classes, centers = self._get_centers(boxes, classes, anchors,
                                                  width, height, offset)
      ind_mask = tf.ones_like(classes)
      updates = tf.concat([boxes, ind_mask, classes], axis=-1)
    else:
      centers, updates = self._get_anchor_free(boxes, classes, height, width,
                                               stride, fpn_limits)
      boxes, ind_mask, classes = tf.split(updates, [4, 1, 1], axis=-1)
      num_anchors = 1

    width = tf.cast(width, tf.int32)
    height = tf.cast(height, tf.int32)
    full = tf.zeros([height, width, num_anchors, 1], dtype=classes.dtype)
    full = tf.tensor_scatter_nd_add(full, centers, ind_mask)

    centers = preprocessing_ops.pad_max_instances(
        centers, int(num_instances), pad_value=0, pad_axis=0)
    updates = preprocessing_ops.pad_max_instances(
        updates, int(num_instances), pad_value=0, pad_axis=0)
    return centers, updates, full
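

# A hedged usage sketch (not from the commit; anchor sizes and box values are
# toy values):
#
#   labeler = YoloAnchorLabeler(
#       anchors={'3': [[10, 13], [16, 30], [33, 23]]})
#   boxes = tf.constant([[0.1, 0.1, 0.4, 0.3]], tf.float32)  # yxyx, normalized
#   classes = tf.constant([0.0], tf.float32)
#   inds, upds, grid = labeler(
#       '3', boxes, classes, labeler.anchors['3'], width=640, height=640,
#       stride=8, scale_xy=2.0, num_instances=200)
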
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mosaic data aug for YOLO."""
import random

import tensorflow as tf
import tensorflow_addons as tfa

from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops


class Mosaic:
  """Stitch together sets of 4 images to generate samples with more boxes."""

  def __init__(self,
@@ -36,6 +23,7 @@ class Mosaic(object):
               aug_rand_perspective=0.0,
               aug_rand_translate=0.0,
               random_pad=False,
               random_flip=False,
               area_thresh=0.1,
               seed=None):
    """Initializes parameters for mosaic.
@@ -91,6 +79,7 @@ class Mosaic(object):
    self._aug_rand_translate = aug_rand_translate
    self._aug_rand_angle = aug_rand_angle
    self._aug_rand_perspective = aug_rand_perspective
    self._random_flip = random_flip
    self._deterministic = seed is not None
    self._seed = seed if seed is not None else random.randint(0, 2**30)
@@ -116,6 +105,12 @@ class Mosaic(object):
        [self._output_size[1] * 2, self._output_size[0] * 2, 3])
    return cut, ishape

  def _select_ind(self, inds, *args):
    """Gathers `inds` from each tensor in `args`."""
    return [tf.gather(item, inds) for item in args]

  def _augment_image(self,
                     image,
                     boxes,
@@ -126,13 +121,16 @@ class Mosaic(object):
                     ys=0.0,
                     cut=None):
    """Process a single image prior to the application of patching."""
    if self._random_flip:
      # Randomly flip the image horizontally.
      image, boxes, _ = preprocess_ops.random_horizontal_flip(
          image, boxes, seed=self._seed)

    # Augment the image without resizing.
    image, infos, crop_points = preprocessing_ops.resize_and_jitter_image(
        image, [self._output_size[0], self._output_size[1]],
        random_pad=False,
        letter_box=self._letter_box,
        jitter=self._random_crop,
        shiftx=xs,
        shifty=ys,
@@ -147,9 +145,7 @@ class Mosaic(object):
        shuffle_boxes=False,
        augment=True,
        seed=self._seed)
    classes, is_crowd, area = self._select_ind(inds, classes, is_crowd, area)
    return image, boxes, classes, is_crowd, area, crop_points

  def _mosaic_crop_image(self, image, boxes, classes, is_crowd, area):
@@ -173,7 +169,11 @@ class Mosaic(object):
    boxes = box_ops.denormalize_boxes(boxes, shape[:2])
    boxes = boxes + tf.cast([ch, cw, ch, cw], boxes.dtype)
    boxes = box_ops.clip_boxes(boxes, shape[:2])
    inds = box_ops.get_non_empty_box_indices(boxes)
    boxes = box_ops.normalize_boxes(boxes, shape[:2])
    boxes, classes, is_crowd, area = self._select_ind(
        inds, boxes, classes, is_crowd, area)

    # Warp and scale the fully stitched sample.
    image, _, affine = preprocessing_ops.affine_warp_image(
@@ -190,15 +190,9 @@ class Mosaic(object):

    # Clip and clean boxes.
    boxes, inds = preprocessing_ops.apply_infos(
        boxes, None, affine=affine, area_thresh=self._area_thresh,
        seed=self._seed)
    classes, is_crowd, area = self._select_ind(inds, classes, is_crowd, area)
    return image, boxes, classes, is_crowd, area, area

  def scale_boxes(self, patch, ishape, boxes, classes, xs, ys):
@@ -224,8 +218,6 @@ class Mosaic(object):
        sample['image'], sample['groundtruth_boxes'],
        sample['groundtruth_classes'], sample['groundtruth_is_crowd'],
        sample['groundtruth_area'], shiftx, shifty, cut)
    (boxes, classes) = self.scale_boxes(image, ishape, boxes, classes,
                                        1 - shiftx, 1 - shifty)
@@ -235,7 +227,6 @@ class Mosaic(object):
    sample['groundtruth_classes'] = classes
    sample['groundtruth_is_crowd'] = is_crowd
    sample['groundtruth_area'] = area
    sample['shiftx'] = shiftx
    sample['shifty'] = shifty
    sample['crop_points'] = crop_points
@@ -284,7 +275,9 @@ class Mosaic(object):
    sample['num_detections'] = tf.shape(sample['groundtruth_boxes'])[1]
    sample['is_mosaic'] = tf.cast(1.0, tf.bool)

    del sample['shiftx']
    del sample['shifty']
    del sample['crop_points']
    return sample

  def _mosaic(self, one, two, three, four):
@@ -349,6 +342,7 @@ class Mosaic(object):

  def _apply(self, dataset):
    """Apply mosaic to an input dataset."""
    determ = self._deterministic
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    one = dataset.shuffle(100, seed=self._seed, reshuffle_each_iteration=True)
    two = dataset.shuffle(
        100, seed=self._seed + 1, reshuffle_each_iteration=True)
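
    # The remainder of `_apply` is collapsed in this diff. A plausible
    # continuation (an assumption, not the commit's code) zips the four
    # shuffled copies and maps `_mosaic` over them:
    #
    #   stitched = tf.data.Dataset.zip((one, two, three, four))
    #   stitched = stitched.map(
    #       self._mosaic, num_parallel_calls=tf.data.AUTOTUNE,
    #       deterministic=determ)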