Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8fba84f8
Commit
8fba84f8
authored
Feb 25, 2021
by
A. Unique TensorFlower
Browse files
Merge pull request #9678 from tensorflow:purdue-yolo
PiperOrigin-RevId: 359601927
parents
ba627d4e
83b992c5
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1373 additions
and
6 deletions
+1373
-6
official/vision/beta/projects/yolo/dataloaders/__init__.py
official/vision/beta/projects/yolo/dataloaders/__init__.py
+1
-0
official/vision/beta/projects/yolo/dataloaders/classification_tfds_decoder.py
.../projects/yolo/dataloaders/classification_tfds_decoder.py
+4
-6
official/vision/beta/projects/yolo/dataloaders/yolo_detection_input.py
...on/beta/projects/yolo/dataloaders/yolo_detection_input.py
+319
-0
official/vision/beta/projects/yolo/dataloaders/yolo_detection_input_test.py
...ta/projects/yolo/dataloaders/yolo_detection_input_test.py
+104
-0
official/vision/beta/projects/yolo/ops/__init__.py
official/vision/beta/projects/yolo/ops/__init__.py
+1
-0
official/vision/beta/projects/yolo/ops/box_ops.py
official/vision/beta/projects/yolo/ops/box_ops.py
+297
-0
official/vision/beta/projects/yolo/ops/box_ops_test.py
official/vision/beta/projects/yolo/ops/box_ops_test.py
+56
-0
official/vision/beta/projects/yolo/ops/preprocess_ops.py
official/vision/beta/projects/yolo/ops/preprocess_ops.py
+524
-0
official/vision/beta/projects/yolo/ops/preprocess_ops_test.py
...cial/vision/beta/projects/yolo/ops/preprocess_ops_test.py
+67
-0
No files found.
official/vision/beta/projects/yolo/dataloaders/__init__.py
0 → 100644
View file @
8fba84f8
official/vision/beta/projects/yolo/dataloaders/classification_tfds_decoder.py
View file @
8fba84f8
...
...
@@ -15,7 +15,6 @@
"""TFDS Classification decoder."""
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
...
...
@@ -27,10 +26,9 @@ class Decoder(decoder.Decoder):
def decode(self, serialized_example):
    """Decodes a TFDS classification example into the keyed sample format.

    The scraped diff duplicated both dict entries (old + new diff lines were
    merged); the duplicates were dead code — Python keeps only the last
    occurrence of a repeated key — so each key appears exactly once here.

    Args:
      serialized_example: a dict with an 'image' uint8 tensor and a 'label'
        tensor, as produced by TFDS.

    Returns:
      A dict with 'image/encoded' (JPEG bytes) and 'image/class/label'.
    """
    sample_dict = {
        'image/encoded':
            tf.io.encode_jpeg(serialized_example['image'], quality=100),
        'image/class/label':
            serialized_example['label'],
    }
    return sample_dict
official/vision/beta/projects/yolo/dataloaders/yolo_detection_input.py
0 → 100644
View file @
8fba84f8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection Data parser and processing for YOLO.
Parse image and ground truths in a dataset to training targets and package them
into (image, labels) tuple for RetinaNet.
"""
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.projects.yolo.ops
import
box_ops
as
yolo_box_ops
from
official.vision.beta.projects.yolo.ops
import
preprocess_ops
as
yolo_preprocess_ops
class Parser(parser.Parser):
  """Parser to parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               output_size,
               num_classes,
               fixed_size=True,
               jitter_im=0.1,
               jitter_boxes=0.005,
               use_tie_breaker=True,
               min_level=3,
               max_level=5,
               masks=None,
               max_process_size=608,
               min_process_size=320,
               max_num_instances=200,
               random_flip=True,
               aug_rand_saturation=True,
               aug_rand_brightness=True,
               aug_rand_zoom=True,
               aug_rand_hue=True,
               anchors=None,
               seed=10,
               dtype=tf.float32):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: a `Tuple` for (width, height) of input image.
      num_classes: a `Tensor` or `int` for the number of classes.
      fixed_size: a `bool` if True all output images have the same size.
      jitter_im: a `float` representing a pixel value that is the maximum
        jitter applied to the image for data augmentation during training.
      jitter_boxes: a `float` representing a pixel value that is the maximum
        jitter applied to the bounding box for data augmentation during
        training.
      use_tie_breaker: boolean value for whether or not to use the
        tie_breaker.
      min_level: `int` number of minimum level of the output feature pyramid.
        NOTE(review): not read anywhere in this class; kept for interface
        compatibility.
      max_level: `int` number of maximum level of the output feature pyramid.
      masks: a `dict` mapping level keys (e.g. '3') to anchor-index masks
        (`Tensor`, `List` or `numpy.ndarray`).
      max_process_size: an `int` for maximum image width and height.
      min_process_size: an `int` for minimum image width and height.
      max_num_instances: an `int` for the maximum number of instances in an
        image.
      random_flip: a `bool` if True, augment training with random horizontal
        flip.
      aug_rand_saturation: `bool`, if True, augment training with random
        saturation.
      aug_rand_brightness: `bool`, if True, augment training with random
        brightness.
      aug_rand_zoom: `bool`, if True, augment training with random zoom.
      aug_rand_hue: `bool`, if True, augment training with random hue.
      anchors: a `Tensor`, `List` or `numpy.ndarray` for bounding box priors.
      seed: an `int` for the seed used by tf.random.
      dtype: a `tf.dtypes.DType` object that represents the dtype the outputs
        will be casted to. The available types are tf.float32, tf.float16, or
        tf.bfloat16.
    """
    # Stride of the deepest pyramid level; output sizes are snapped down to a
    # multiple of this so features align with the grid.
    self._net_down_scale = 2**max_level
    self._num_classes = num_classes
    self._image_w = (output_size[0] // self._net_down_scale) * self._net_down_scale
    self._image_h = (output_size[1] // self._net_down_scale) * self._net_down_scale
    self._max_process_size = max_process_size
    self._min_process_size = min_process_size
    self._fixed_size = fixed_size
    self._anchors = anchors
    # Fix: tolerate the documented default masks=None. The original called
    # masks.items() unconditionally and raised AttributeError for None.
    self._masks = {
        key: tf.convert_to_tensor(value)
        for key, value in (masks or {}).items()
    }
    self._use_tie_breaker = use_tie_breaker
    self._jitter_im = 0.0 if jitter_im is None else jitter_im
    self._jitter_boxes = 0.0 if jitter_boxes is None else jitter_boxes
    self._max_num_instances = max_num_instances
    self._random_flip = random_flip
    self._aug_rand_saturation = aug_rand_saturation
    self._aug_rand_brightness = aug_rand_brightness
    self._aug_rand_zoom = aug_rand_zoom
    self._aug_rand_hue = aug_rand_hue
    self._seed = seed
    self._dtype = dtype

  def _build_grid(self, raw_true, width, batch=False, use_tie_breaker=False):
    """Builds the per-level YOLO ground-truth grids.

    Args:
      raw_true: a dict of ground-truth tensors; must contain 'bbox'.
      width: image width used to derive each level's grid resolution.
      batch: a `bool`; if True the inputs carry a leading batch dimension and
        the batched grid builder is used.
      use_tie_breaker: a `bool` forwarded to the grid builders.

    Returns:
      A dict mapping each mask key (level) to its ground-truth grid tensor.
    """
    # Fix: the original wrote `mask = self._masks` and assigned the computed
    # grids into that alias, silently clobbering the parser's anchor-mask
    # dict so every later call saw grids instead of masks. Build a fresh
    # dict instead and leave self._masks untouched.
    grids = {}
    for key in self._masks.keys():
      # Grid resolution for level `key` is width / 2**level.
      if not batch:
        grids[key] = yolo_preprocess_ops.build_grided_gt(
            raw_true, self._masks[key], width // 2**int(key),
            raw_true['bbox'].dtype, use_tie_breaker)
      else:
        grids[key] = yolo_preprocess_ops.build_batch_grided_gt(
            raw_true, self._masks[key], width // 2**int(key),
            raw_true['bbox'].dtype, use_tie_breaker)
    return grids

  def _parse_train_data(self, data):
    """Generates images and labels that are usable for model training.

    Args:
      data: a dict of Tensors produced by the decoder.

    Returns:
      images: the image tensor.
      labels: a dict of Tensors that contains labels.
    """
    shape = tf.shape(data['image'])
    # Scale pixel values into [0, 1].
    image = data['image'] / 255
    boxes = data['groundtruth_boxes']
    # NOTE(review): tf image tensors are (height, width, channels), so
    # shape[0] is the height dimension — these names look swapped; verify
    # against fit_preserve_aspect_ratio's expectations.
    width = shape[0]
    height = shape[1]

    image, boxes = yolo_preprocess_ops.fit_preserve_aspect_ratio(
        image, boxes, width=width, height=height,
        target_dim=self._max_process_size)

    image_shape = tf.shape(image)[:2]

    if self._random_flip:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(
          image, boxes, seed=self._seed)

    randscale = self._image_w // self._net_down_scale
    if not self._fixed_size:
      # With probability 0.5, pick a random multiple of the network stride
      # between min_process_size and max_process_size.
      do_scale = tf.greater(
          tf.random.uniform([], minval=0, maxval=1, seed=self._seed), 0.5)
      if do_scale:
        # This scales the image to a random multiple of net_down_scale
        # between 320 to 608
        randscale = tf.random.uniform(
            [],
            minval=self._min_process_size // self._net_down_scale,
            maxval=self._max_process_size // self._net_down_scale,
            seed=self._seed,
            dtype=tf.int32) * self._net_down_scale

    if self._jitter_boxes != 0.0:
      boxes = box_ops.denormalize_boxes(boxes, image_shape)
      # NOTE(review): jitters by a fixed 0.025 even though the gate tests
      # self._jitter_boxes — confirm whether the constant is intended.
      boxes = box_ops.jitter_boxes(boxes, 0.025)
      boxes = box_ops.normalize_boxes(boxes, image_shape)

    # YOLO loss function uses x-center, y-center format
    boxes = yolo_box_ops.yxyx_to_xcycwh(boxes)

    if self._jitter_im != 0.0:
      image, boxes = yolo_preprocess_ops.random_translate(
          image, boxes, self._jitter_im, seed=self._seed)

    if self._aug_rand_zoom:
      image, boxes = yolo_preprocess_ops.resize_crop_filter(
          image,
          boxes,
          default_width=self._image_w,
          default_height=self._image_h,
          target_width=randscale,
          target_height=randscale)

    # NOTE(review): hard-codes 416x416 rather than (self._image_h,
    # self._image_w) — confirm intended.
    image = tf.image.resize(image, (416, 416), preserve_aspect_ratio=False)

    if self._aug_rand_brightness:
      image = tf.image.random_brightness(image=image, max_delta=.1)  # Brightness
    if self._aug_rand_saturation:
      image = tf.image.random_saturation(
          image=image, lower=0.75, upper=1.25)  # Saturation
    if self._aug_rand_hue:
      image = tf.image.random_hue(image=image, max_delta=.3)  # Hue
    image = tf.clip_by_value(image, 0.0, 1.0)

    # Find the best anchor for the ground truth labels to maximize the iou
    best_anchors = yolo_preprocess_ops.get_best_anchor(
        boxes, self._anchors, width=self._image_w, height=self._image_h)

    # Padding
    boxes = preprocess_ops.clip_or_pad_to_fixed_size(
        boxes, self._max_num_instances, 0)
    classes = preprocess_ops.clip_or_pad_to_fixed_size(
        data['groundtruth_classes'], self._max_num_instances, -1)
    best_anchors = preprocess_ops.clip_or_pad_to_fixed_size(
        best_anchors, self._max_num_instances, 0)
    area = preprocess_ops.clip_or_pad_to_fixed_size(
        data['groundtruth_area'], self._max_num_instances, 0)
    is_crowd = preprocess_ops.clip_or_pad_to_fixed_size(
        tf.cast(data['groundtruth_is_crowd'], tf.int32),
        self._max_num_instances, 0)

    labels = {
        'source_id': data['source_id'],
        'bbox': tf.cast(boxes, self._dtype),
        'classes': tf.cast(classes, self._dtype),
        'area': tf.cast(area, self._dtype),
        'is_crowd': is_crowd,
        'best_anchors': tf.cast(best_anchors, self._dtype),
        'width': width,
        'height': height,
        'num_detections': tf.shape(data['groundtruth_classes'])[0],
    }

    # When sizes are fixed the grids can be built per example here; the
    # variable-size path defers to _postprocess_fn after batching.
    if self._fixed_size:
      grid = self._build_grid(
          labels, self._image_w, use_tie_breaker=self._use_tie_breaker)
      labels.update({'grid_form': grid})
    return image, labels

  def _parse_eval_data(self, data):
    """Generates images and labels that are usable for model evaluation.

    Args:
      data: a dict of Tensors produced by the decoder.

    Returns:
      images: the image tensor.
      labels: a dict of Tensors that contains labels.
    """
    shape = tf.shape(data['image'])
    image = data['image'] / 255
    boxes = data['groundtruth_boxes']
    # NOTE(review): same width/height naming caveat as _parse_train_data.
    width = shape[0]
    height = shape[1]

    image, boxes = yolo_preprocess_ops.fit_preserve_aspect_ratio(
        image, boxes, width=width, height=height, target_dim=self._image_w)
    boxes = yolo_box_ops.yxyx_to_xcycwh(boxes)

    # Find the best anchor for the ground truth labels to maximize the iou
    best_anchors = yolo_preprocess_ops.get_best_anchor(
        boxes, self._anchors, width=self._image_w, height=self._image_h)

    # Pad all per-instance tensors to max_num_instances. NOTE(review): the
    # train path pads classes with -1, this path pads with 0 — confirm the
    # inconsistency is intended.
    boxes = yolo_preprocess_ops.pad_max_instances(
        boxes, self._max_num_instances, 0)
    classes = yolo_preprocess_ops.pad_max_instances(
        data['groundtruth_classes'], self._max_num_instances, 0)
    best_anchors = yolo_preprocess_ops.pad_max_instances(
        best_anchors, self._max_num_instances, 0)
    area = yolo_preprocess_ops.pad_max_instances(
        data['groundtruth_area'], self._max_num_instances, 0)
    is_crowd = yolo_preprocess_ops.pad_max_instances(
        tf.cast(data['groundtruth_is_crowd'], tf.int32),
        self._max_num_instances, 0)

    labels = {
        'source_id': data['source_id'],
        'bbox': tf.cast(boxes, self._dtype),
        'classes': tf.cast(classes, self._dtype),
        'area': tf.cast(area, self._dtype),
        'is_crowd': is_crowd,
        'best_anchors': tf.cast(best_anchors, self._dtype),
        'width': width,
        'height': height,
        'num_detections': tf.shape(data['groundtruth_classes'])[0],
    }

    grid = self._build_grid(
        labels, self._image_w, batch=False,
        use_tie_breaker=self._use_tie_breaker)
    labels.update({'grid_form': grid})
    return image, labels

  def _postprocess_fn(self, image, label):
    """Randomly rescales a batched image and rebuilds the grids to match.

    Only used on the variable-size training path (see postprocess_fn).

    Args:
      image: a batched image tensor.
      label: a dict of batched label tensors (mutated in place with the new
        'grid_form' entry).

    Returns:
      The resized image and the updated label dict.
    """
    randscale = self._image_w // self._net_down_scale
    if not self._fixed_size:
      do_scale = tf.greater(
          tf.random.uniform([], minval=0, maxval=1, seed=self._seed), 0.5)
      if do_scale:
        # This scales the image to a random multiple of net_down_scale
        # between 320 to 608
        randscale = tf.random.uniform(
            [],
            minval=self._min_process_size // self._net_down_scale,
            maxval=self._max_process_size // self._net_down_scale,
            seed=self._seed,
            dtype=tf.int32) * self._net_down_scale
    width = randscale
    image = tf.image.resize(image, (width, width))
    grid = self._build_grid(
        label, width, batch=True, use_tie_breaker=self._use_tie_breaker)
    label.update({'grid_form': grid})
    return image, label

  def postprocess_fn(self, is_training=True):
    """Returns the batched post-process callable, or None when not needed.

    Dynamic rescaling only applies when training with variable image sizes.
    """
    return self._postprocess_fn if not self._fixed_size and is_training else None
official/vision/beta/projects/yolo/dataloaders/yolo_detection_input_test.py
0 → 100644
View file @
8fba84f8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test case for YOLO detection dataloader configuration definition."""
from
absl.testing
import
parameterized
import
dataclasses
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
input_reader
from
official.modeling
import
hyperparams
from
official.vision.beta.dataloaders
import
tfds_detection_decoders
from
official.vision.beta.projects.yolo.dataloaders
import
yolo_detection_input
@dataclasses.dataclass
class Parser(hyperparams.Config):
  """Dummy configuration for parser."""
  # NOTE(review): annotated `int` but the default is a (width, height) tuple;
  # the annotation looks wrong — confirm against hyperparams.Config's type
  # validation before changing it.
  output_size: int = (416, 416)
  num_classes: int = 80
  fixed_size: bool = True
  # Maximum pixel jitter applied to the image during augmentation.
  jitter_im: float = 0.1
  # Maximum jitter applied to bounding boxes during augmentation.
  jitter_boxes: float = 0.005
  min_process_size: int = 320
  max_process_size: int = 608
  max_num_instances: int = 200
  random_flip: bool = True
  seed: int = 10
  shuffle_buffer_size: int = 10000
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  input_path: str = ''
  tfds_name: str = 'coco/2017'
  tfds_split: str = 'train'
  global_batch_size: int = 10
  is_training: bool = True
  dtype: str = 'float16'
  # NOTE(review): no annotation, so this is a plain class attribute rather
  # than a dataclass field — confirm that is intended.
  decoder = None
  # NOTE(review): mutable default instance; presumably hyperparams.Config
  # supports this where stdlib dataclasses would reject it — verify.
  parser: Parser = Parser()
  shuffle_buffer_size: int = 10
  tfds_download: bool = False
class YoloDetectionInputTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end smoke test for the YOLO detection input pipeline."""

  @parameterized.named_parameters(('training', True), ('testing', False))
  def test_yolo_input(self, is_training):
    params = DataConfig(is_training=is_training)
    decoder = tfds_detection_decoders.MSCOCODecoder()
    # Nine anchor priors (width, height), three per pyramid level.
    anchors = [[12.0, 19.0], [31.0, 46.0], [96.0, 54.0], [46.0, 114.0],
               [133.0, 127.0], [79.0, 225.0], [301.0, 150.0], [172.0, 286.0],
               [348.0, 340.0]]
    # Level key -> indices into `anchors` for that level.
    masks = {'3': [0, 1, 2], '4': [3, 4, 5], '5': [6, 7, 8]}
    parser = yolo_detection_input.Parser(
        output_size=params.parser.output_size,
        num_classes=params.parser.num_classes,
        fixed_size=params.parser.fixed_size,
        jitter_im=params.parser.jitter_im,
        jitter_boxes=params.parser.jitter_boxes,
        min_process_size=params.parser.min_process_size,
        max_process_size=params.parser.max_process_size,
        max_num_instances=params.parser.max_num_instances,
        random_flip=params.parser.random_flip,
        seed=params.parser.seed,
        anchors=anchors,
        masks=masks)
    # None when fixed_size or not training; otherwise the rescale callable.
    postprocess_fn = parser.postprocess_fn(is_training=is_training)
    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=None).batch(10).take(1)
    if postprocess_fn:
      image, _ = postprocess_fn(
          *tf.data.experimental.get_single_element(dataset))
    else:
      image, _ = tf.data.experimental.get_single_element(dataset)
    # NOTE(review): leftover debug print — consider removing.
    print(image.shape)
    # Batched twice (reader batch + .batch(10)): (10, 10, 416, 416, 3).
    self.assertAllEqual(image.shape, (10, 10, 416, 416, 3))
    # Pixel values must stay normalized to [0, 1].
    self.assertTrue(
        tf.reduce_all(tf.math.logical_and(image >= 0, image <= 1)))


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/yolo/ops/__init__.py
0 → 100644
View file @
8fba84f8
official/vision/beta/projects/yolo/ops/box_ops.py
0 → 100644
View file @
8fba84f8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bounding box utils."""
import
math
import
tensorflow
as
tf
def yxyx_to_xcycwh(box: tf.Tensor):
  """Converts boxes from ymin, xmin, ymax, xmax to xc, yc, width, height.

  Args:
    box: `Tensor` whose shape is [..., 4] and represents the coordinates
      of boxes in ymin, xmin, ymax, xmax.

  Returns:
    `Tensor` whose shape is [..., 4] holding x_center, y_center, width,
    height.

  Raises:
    ValueError: If the last dimension of box is not 4 or if box's dtype isn't
      a floating point type.
  """
  with tf.name_scope('yxyx_to_xcycwh'):
    y_lo, x_lo, y_hi, x_hi = tf.split(box, 4, axis=-1)
    # Midpoints first, then extents, in (x, y) order.
    centers = [(x_hi + x_lo) / 2, (y_hi + y_lo) / 2]
    extents = [x_hi - x_lo, y_hi - y_lo]
    box = tf.concat(centers + extents, axis=-1)
  return box
def xcycwh_to_yxyx(box: tf.Tensor, split_min_max: bool = False):
  """Converts boxes from xc, yc, width, height to ymin, xmin, ymax, xmax.

  Args:
    box: a `Tensor` whose shape is [..., 4] and represents the coordinates
      of boxes in x_center, y_center, width, height.
    split_min_max: bool, whether or not to split the result into min and max
      halves along the last axis.

  Returns:
    box: a `Tensor` whose shape is [..., 4] in ymin, xmin, ymax, xmax order
      (or a two-way split of it when split_min_max is True).

  Raises:
    ValueError: If the last dimension of box is not 4 or if box's dtype isn't
      a floating point type.
  """
  with tf.name_scope('xcycwh_to_yxyx'):
    centers, extents = tf.split(box, 2, axis=-1)
    # Corner coordinates from center +/- half-extent, still in (x, y) order.
    lo = centers - extents / 2
    hi = centers + extents / 2
    x_lo, y_lo = tf.split(lo, 2, axis=-1)
    x_hi, y_hi = tf.split(hi, 2, axis=-1)
    # Reassemble in (y, x) order.
    box = tf.concat([y_lo, x_lo, y_hi, x_hi], axis=-1)
    if split_min_max:
      box = tf.split(box, 2, axis=-1)
  return box
def xcycwh_to_xyxy(box: tf.Tensor, split_min_max: bool = False):
  """Converts boxes from x_center, y_center, width, height to
  xmin, ymin, xmax, ymax.

  Args:
    box: a `Tensor` whose shape is [..., 4] and represents the
      coordinates of boxes in x_center, y_center, width, height.
    split_min_max: bool, whether or not to split x, y min and max values.

  Returns:
    box: a `Tensor` whose shape is [..., 4] and contains the new format
      (or a (min, max) tuple when split_min_max is True).

  Raises:
    ValueError: If the last dimension of box is not 4 or if box's dtype isn't
      a floating point type.
  """
  # Fix: the scope name was copy-pasted as 'xcycwh_to_yxyx', mislabelling
  # this function's ops in the graph; use the correct name.
  with tf.name_scope('xcycwh_to_xyxy'):
    xy, wh = tf.split(box, 2, axis=-1)
    xy_min = xy - wh / 2
    xy_max = xy + wh / 2
    box = (xy_min, xy_max)
    if not split_min_max:
      box = tf.concat(box, axis=-1)
  return box
def center_distance(center_1: tf.Tensor, center_2: tf.Tensor):
  """Calculates the squared distance between two points.

  This function is mathematically equivalent to
  tf.norm(center_1 - center_2, axis=-1)**2 but has smaller rounding errors.

  Args:
    center_1: a `Tensor` whose shape is [..., 2] and represents a point.
    center_2: a `Tensor` whose shape is [..., 2] and represents a point.

  Returns:
    A `Tensor` whose shape is [...] holding the squared distance between
    center_1 and center_2.

  Raises:
    ValueError: If the last dimension of either center_1 or center_2 is not 2.
  """
  with tf.name_scope('center_distance'):
    delta_x = center_1[..., 0] - center_2[..., 0]
    delta_y = center_1[..., 1] - center_2[..., 1]
    dist = delta_x**2 + delta_y**2
  return dist
def compute_iou(box1, box2, yxyx=False):
  """Calculates the intersection over union between box1 and box2.

  Args:
    box1: a `Tensor` whose shape is [..., 4]; coordinates are x_center,
      y_center, width, height unless yxyx is True.
    box2: a `Tensor` whose shape is [..., 4]; same format as box1.
    yxyx: `bool`, whether box1 and box2 are already in yxyx corner format.

  Returns:
    iou: a `Tensor` whose shape is [...] holding the intersection over union,
      clipped to [0, 1].

  Raises:
    ValueError: If the last dimension of either box1 or box2 is not 4.
  """
  with tf.name_scope('iou'):
    # Normalize both inputs to corner (min, max) format.
    if not yxyx:
      box1 = xcycwh_to_yxyx(box1)
      box2 = xcycwh_to_yxyx(box2)

    lo_1, hi_1 = tf.split(box1, 2, axis=-1)
    lo_2, hi_2 = tf.split(box2, 2, axis=-1)

    # Overlap rectangle; clamp negative extents (disjoint boxes) to zero.
    overlap_lo = tf.math.maximum(lo_1, lo_2)
    overlap_hi = tf.math.minimum(hi_1, hi_2)
    overlap_wh = tf.math.maximum(overlap_hi - overlap_lo,
                                 tf.zeros_like(overlap_lo))
    intersection = tf.reduce_prod(overlap_wh, axis=-1)

    area_1 = tf.math.abs(tf.reduce_prod(hi_1 - lo_1, axis=-1))
    area_2 = tf.math.abs(tf.reduce_prod(hi_2 - lo_2, axis=-1))
    union = area_1 + area_2 - intersection

    # Small epsilon guards against division by zero for degenerate boxes.
    iou = intersection / (union + 1e-7)
    iou = tf.clip_by_value(iou, clip_value_min=0.0, clip_value_max=1.0)
  return iou
def compute_giou(box1, box2):
  """Calculates the generalized intersection over union between box1 and box2.

  Args:
    box1: a `Tensor` whose shape is [..., 4] and represents the coordinates of
      boxes in x_center, y_center, width, height.
    box2: a `Tensor` whose shape is [..., 4] and represents the coordinates of
      boxes in x_center, y_center, width, height.

  Returns:
    iou: a `Tensor` whose shape is [...]; the plain intersection over union.
    giou: a `Tensor` whose shape is [...]; the generalized IoU
      (iou - (enclosing_area - union) / enclosing_area).

  Raises:
    ValueError: If the last dimension of either box1 or box2 is not 4.
  """
  with tf.name_scope('giou'):
    # get box corners
    box1 = xcycwh_to_yxyx(box1)
    box2 = xcycwh_to_yxyx(box2)

    # compute IOU
    intersect_mins = tf.math.maximum(box1[..., 0:2], box2[..., 0:2])
    intersect_maxes = tf.math.minimum(box1[..., 2:4], box2[..., 2:4])
    # Clamp to zero so disjoint boxes contribute no intersection.
    intersect_wh = tf.math.maximum(intersect_maxes - intersect_mins,
                                   tf.zeros_like(intersect_mins))
    intersection = intersect_wh[..., 0] * intersect_wh[..., 1]

    box1_area = tf.math.abs(
        tf.reduce_prod(box1[..., 2:4] - box1[..., 0:2], axis=-1))
    box2_area = tf.math.abs(
        tf.reduce_prod(box2[..., 2:4] - box2[..., 0:2], axis=-1))
    union = box1_area + box2_area - intersection

    iou = tf.math.divide_no_nan(intersection, union)
    iou = tf.clip_by_value(iou, clip_value_min=0.0, clip_value_max=1.0)

    # find the smallest box to encompass both box1 and box2
    c_mins = tf.math.minimum(box1[..., 0:2], box2[..., 0:2])
    c_maxes = tf.math.maximum(box1[..., 2:4], box2[..., 2:4])
    # abs() makes the sign of (mins - maxes) irrelevant for the area.
    c = tf.math.abs(tf.reduce_prod(c_mins - c_maxes, axis=-1))

    # compute giou
    giou = iou - tf.math.divide_no_nan((c - union), c)
  return iou, giou
def compute_diou(box1, box2):
  """Calculates the distance intersection over union between box1 and box2.

  Args:
    box1: a `Tensor` whose shape is [..., 4] and represents the coordinates of
      boxes in x_center, y_center, width, height.
    box2: a `Tensor` whose shape is [..., 4] and represents the coordinates of
      boxes in x_center, y_center, width, height.

  Returns:
    iou: a `Tensor` whose shape is [...]; the plain intersection over union.
    diou: a `Tensor` whose shape is [...]; the distance IoU.

  Raises:
    ValueError: If the last dimension of either box1 or box2 is not 4.
  """
  with tf.name_scope('diou'):
    # compute squared center distance (while boxes are still center-format)
    dist = center_distance(box1[..., 0:2], box2[..., 0:2])

    # get box corners
    box1 = xcycwh_to_yxyx(box1)
    box2 = xcycwh_to_yxyx(box2)

    # compute IOU
    intersect_mins = tf.math.maximum(box1[..., 0:2], box2[..., 0:2])
    intersect_maxes = tf.math.minimum(box1[..., 2:4], box2[..., 2:4])
    intersect_wh = tf.math.maximum(intersect_maxes - intersect_mins,
                                   tf.zeros_like(intersect_mins))
    intersection = intersect_wh[..., 0] * intersect_wh[..., 1]

    box1_area = tf.math.abs(
        tf.reduce_prod(box1[..., 2:4] - box1[..., 0:2], axis=-1))
    box2_area = tf.math.abs(
        tf.reduce_prod(box2[..., 2:4] - box2[..., 0:2], axis=-1))
    union = box1_area + box2_area - intersection

    iou = tf.math.divide_no_nan(intersection, union)
    iou = tf.clip_by_value(iou, clip_value_min=0.0, clip_value_max=1.0)

    # compute max diagonal of the smallest enclosing box
    c_mins = tf.math.minimum(box1[..., 0:2], box2[..., 0:2])
    c_maxes = tf.math.maximum(box1[..., 2:4], box2[..., 2:4])
    diag_dist = tf.reduce_sum((c_maxes - c_mins)**2, axis=-1)

    regularization = tf.math.divide_no_nan(dist, diag_dist)
    # NOTE(review): the DIoU paper defines DIoU = IoU - dist/diag; here the
    # penalty is ADDED. Downstream loss code may compensate — verify before
    # changing the sign.
    diou = iou + regularization
  return iou, diou
def compute_ciou(box1, box2):
  """Calculates the complete intersection over union between box1 and box2.

  Args:
    box1: a `Tensor` whose shape is [..., 4] and represents the coordinates
      of boxes in x_center, y_center, width, height.
    box2: a `Tensor` whose shape is [..., 4] and represents the coordinates of
      boxes in x_center, y_center, width, height.

  Returns:
    iou: a `Tensor` whose shape is [...]; the plain intersection over union.
    ciou: a `Tensor` whose shape is [...]; the complete IoU.

  Raises:
    ValueError: If the last dimension of either box1 or box2 is not 4.
  """
  with tf.name_scope('ciou'):
    # DIOU carries both the plain IOU and the distance-regularized term.
    iou, diou = compute_diou(box1, box2)

    # Aspect-ratio consistency: difference of the two boxes' w/h arctangents.
    ratio_diff = (
        tf.math.atan(tf.math.divide_no_nan(box1[..., 2], box1[..., 3])) -
        tf.math.atan(tf.math.divide_no_nan(box2[..., 2], box2[..., 3])))
    v = 4 * ratio_diff**2 / (math.pi)**2

    # Trade-off weight for the aspect-ratio penalty.
    a = tf.math.divide_no_nan(v, ((1 - iou) + v))
    ciou = diou + v * a
  return iou, ciou
official/vision/beta/projects/yolo/ops/box_ops_test.py
0 → 100644
View file @
8fba84f8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.projects.yolo.ops
import
box_ops
class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
  """Shape and identity-IOU sanity checks for the yolo box_ops module."""

  @parameterized.parameters((1), (4))
  def test_box_conversions(self, num_boxes):
    # Every conversion must preserve the [num_boxes, 4] shape.
    boxes = tf.convert_to_tensor(np.random.rand(num_boxes, 4))
    expected_shape = np.array([num_boxes, 4])
    xywh_box = box_ops.yxyx_to_xcycwh(boxes)
    yxyx_box = box_ops.xcycwh_to_yxyx(boxes)
    xyxy_box = box_ops.xcycwh_to_xyxy(boxes)
    self.assertAllEqual(tf.shape(xywh_box).numpy(), expected_shape)
    self.assertAllEqual(tf.shape(yxyx_box).numpy(), expected_shape)
    self.assertAllEqual(tf.shape(xyxy_box).numpy(), expected_shape)

  @parameterized.parameters((1), (5), (7))
  def test_ious(self, num_boxes):
    # A box compared against itself must score (near) 1 for all variants.
    boxes = tf.convert_to_tensor(np.random.rand(num_boxes, 4))
    expected_shape = np.array([
        num_boxes,
    ])
    expected_iou = np.ones([
        num_boxes,
    ])
    iou = box_ops.compute_iou(boxes, boxes)
    _, giou = box_ops.compute_giou(boxes, boxes)
    _, ciou = box_ops.compute_ciou(boxes, boxes)
    _, diou = box_ops.compute_diou(boxes, boxes)
    self.assertAllEqual(tf.shape(iou).numpy(), expected_shape)
    self.assertArrayNear(iou, expected_iou, 0.001)
    self.assertArrayNear(giou, expected_iou, 0.001)
    self.assertArrayNear(ciou, expected_iou, 0.001)
    self.assertArrayNear(diou, expected_iou, 0.001)


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/yolo/ops/preprocess_ops.py
0 → 100644
View file @
8fba84f8
This diff is collapsed.
Click to expand it.
official/vision/beta/projects/yolo/ops/preprocess_ops_test.py
0 → 100644
View file @
8fba84f8
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.projects.yolo.ops
import
preprocess_ops
class PreprocessOpsTest(parameterized.TestCase, tf.test.TestCase):
  """Shape checks for the yolo preprocess_ops module."""

  @parameterized.parameters((416, 416, 5, 300, 300), (100, 200, 6, 50, 50))
  def test_resize_crop_filter(self, default_width, default_height, num_boxes,
                              target_width, target_height):
    image = tf.convert_to_tensor(
        np.random.rand(default_width, default_height, 3))
    boxes = tf.convert_to_tensor(np.random.rand(num_boxes, 4))
    resized_image, resized_boxes = preprocess_ops.resize_crop_filter(
        image, boxes, default_width, default_height, target_width,
        target_height)
    resized_image_shape = tf.shape(resized_image)
    resized_boxes_shape = tf.shape(resized_boxes)
    # NOTE(review): the input image is built (width, height, 3) but the
    # output is asserted (height, width, 3) — implies the op reshapes to
    # (default_height, default_width); verify against resize_crop_filter.
    self.assertAllEqual([default_height, default_width, 3],
                        resized_image_shape.numpy())
    self.assertAllEqual([num_boxes, 4], resized_boxes_shape.numpy())

  @parameterized.parameters((7, 7., 5.), (25, 35., 45.))
  def test_translate_boxes(self, num_boxes, translate_x, translate_y):
    # Translation must not change the number or arity of boxes.
    boxes = tf.convert_to_tensor(np.random.rand(num_boxes, 4))
    translated_boxes = preprocess_ops.translate_boxes(boxes, translate_x,
                                                      translate_y)
    translated_boxes_shape = tf.shape(translated_boxes)
    self.assertAllEqual([num_boxes, 4], translated_boxes_shape.numpy())

  @parameterized.parameters((100, 200, 75., 25.), (400, 600, 25., 75.))
  def test_translate_image(self, image_height, image_width, translate_x,
                           translate_y):
    # 4-channel image; translation must preserve the full shape.
    image = tf.convert_to_tensor(np.random.rand(image_height, image_width, 4))
    translated_image = preprocess_ops.translate_image(image, translate_x,
                                                      translate_y)
    translated_image_shape = tf.shape(translated_image)
    self.assertAllEqual([image_height, image_width, 4],
                        translated_image_shape.numpy())

  @parameterized.parameters(([1, 2], 20, 0), ([13, 2, 4], 15, 0))
  def test_pad_max_instances(self, input_shape, instances, pad_axis):
    # Padding along pad_axis must grow that axis to `instances`.
    expected_output_shape = input_shape
    expected_output_shape[pad_axis] = instances
    output = preprocess_ops.pad_max_instances(
        np.ones(input_shape), instances, pad_axis=pad_axis)
    self.assertAllEqual(expected_output_shape, tf.shape(output).numpy())


if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment