Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8e9db0f1
"docs/vscode:/vscode.git/clone" did not exist on "326bbf034e90c51c33607ed251508f365bb1e6bb"
Commit
8e9db0f1
authored
Sep 10, 2021
by
Vishnu Banna
Browse files
loss function and init/run test
parent
a1df6e20
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1374 additions
and
0 deletions
+1374
-0
official/vision/beta/projects/yolo/losses/__init__.py
official/vision/beta/projects/yolo/losses/__init__.py
+0
-0
official/vision/beta/projects/yolo/losses/yolo_loss.py
official/vision/beta/projects/yolo/losses/yolo_loss.py
+654
-0
official/vision/beta/projects/yolo/losses/yolo_loss_test.py
official/vision/beta/projects/yolo/losses/yolo_loss_test.py
+103
-0
official/vision/beta/projects/yolo/ops/loss_utils.py
official/vision/beta/projects/yolo/ops/loss_utils.py
+617
-0
No files found.
official/vision/beta/projects/yolo/losses/__init__.py
0 → 100644
View file @
8e9db0f1
official/vision/beta/projects/yolo/losses/yolo_loss.py
0 → 100755
View file @
8e9db0f1
import
tensorflow
as
tf
from
collections
import
defaultdict
import
abc
from
functools
import
partial
from
official.vision.beta.projects.yolo.ops
import
(
loss_utils
,
box_ops
,
math_ops
)
class YoloLossBase(object, metaclass=abc.ABCMeta):
  """Parameters for the YOLO loss functions used at each detection
  generator. This base class implements the base functionality required to
  implement a Yolo Loss function: IOU-based box losses, the global
  ground-truth search, and metric computation. Subclasses provide the
  per-path logic via `_build_per_path_attributes` and `call`."""

  def __init__(self,
               classes,
               mask,
               anchors,
               path_stride=1,
               ignore_thresh=0.7,
               truth_thresh=1.0,
               loss_type="ciou",
               iou_normalizer=1.0,
               cls_normalizer=1.0,
               obj_normalizer=1.0,
               label_smoothing=0.0,
               objectness_smooth=True,
               update_on_repeat=False,
               box_type="original",
               scale_x_y=1.0,
               max_delta=10):
    """Loss Function Initialization.

    Args:
      classes: `int` for the number of classes.
      mask: `List[int]` for the anchor indexes used at this specific model
        output level.
      anchors: `List[List[int]]` for the anchor boxes that are used in the
        model at all levels. For anchor free prediction set the anchor list
        to be the same as the image resolution.
      path_stride: `int` for how much to scale this level to get the
        original input shape.
      ignore_thresh: `float` for the IOU value over which the loss is not
        propagated, and a detection is assumed to have been made.
      truth_thresh: `float` for the IOU value over which the loss is
        propagated despite a detection being made.
      loss_type: `str` for the type of iou loss to use with in {ciou, diou,
        giou, iou}.
      iou_normalizer: `float` for how much to scale the loss on the IOU or
        the boxes.
      cls_normalizer: `float` for how much to scale the loss on the classes.
      obj_normalizer: `float` for how much to scale loss on the detection
        map.
      label_smoothing: `float` for how much to smooth the loss on the
        classes.
      objectness_smooth: `float` for how much to smooth the loss on the
        detection map.
      update_on_repeat: `bool` for whether to replace with the newest or the
        best value when an index is consumed by multiple objects.
      box_type: `str` for which box decoding/scaling type to use.
      scale_x_y: `float` value indicating how far each pixel can see outside
        of its containment of 1.0. A value of 1.2 indicates there is a 20%
        extended radius around each pixel that this specific pixel can
        predict values for a center at. The center can range from 0 - value/2
        to 1 + value/2; this value is set in the yolo filter and reused here.
        There should be one value for scale_x_y for each level from min_level
        to max_level.
      max_delta: gradient clipping to apply to the box loss.
    """
    self._loss_type = loss_type
    self._classes = tf.constant(tf.cast(classes, dtype=tf.int32))
    # Number of anchors assigned to this FPN level.
    self._num = tf.cast(len(mask), dtype=tf.int32)
    self._truth_thresh = truth_thresh
    self._ignore_thresh = ignore_thresh
    self._masks = mask
    self._anchors = anchors

    self._iou_normalizer = iou_normalizer
    self._cls_normalizer = cls_normalizer
    self._obj_normalizer = obj_normalizer
    self._scale_x_y = scale_x_y
    self._max_delta = max_delta

    self._label_smoothing = tf.cast(label_smoothing, tf.float32)
    self._objectness_smooth = float(objectness_smooth)
    self._update_on_repeat = update_on_repeat

    self._box_type = box_type
    self._path_stride = path_stride

    # Pre-bind the static box-decoding parameters so subclasses only pass
    # the per-call tensors to self._decode_boxes.
    box_kwargs = dict(
        stride=self._path_stride,
        scale_xy=self._scale_x_y,
        box_type=self._box_type,
        max_delta=self._max_delta)
    self._decode_boxes = partial(loss_utils.get_predicted_box, **box_kwargs)

    # Allow the concrete subclass to build its search/grid helpers.
    self._build_per_path_attributes()

  def box_loss(self, true_box, pred_box, darknet=False):
    """Calls the iou functions and uses it to compute the loss this op is
    the same regardless of Yolo Loss version.

    Returns a tuple of (iou, regularized iou used for the loss, box loss).
    """
    if self._loss_type == "giou":
      iou, liou = box_ops.compute_giou(true_box, pred_box)
    elif self._loss_type == "ciou":
      iou, liou = box_ops.compute_ciou(true_box, pred_box, darknet=darknet)
    else:
      # Plain IOU: the loss IOU and the metric IOU are the same tensor.
      liou = iou = box_ops.compute_iou(true_box, pred_box)
    loss_box = 1 - liou
    return iou, liou, loss_box

  def _tiled_global_box_search(self,
                               pred_boxes,
                               pred_classes,
                               boxes,
                               classes,
                               true_conf,
                               smoothed,
                               scale=None):
    """Completes a search of all predictions against all the ground truths
    to dynamically associate ground truths with predictions."""

    # Search all predictions against ground truths to find matching boxes
    # for each pixel.
    _, _, iou_max, _ = self._search_pairs(
        pred_boxes, pred_classes, boxes, classes, scale=scale, yxyx=True)

    # Find the exact indexes to ignore and keep.
    ignore_mask = tf.cast(iou_max < self._ignore_thresh, pred_boxes.dtype)
    iou_mask = iou_max > self._ignore_thresh

    if not smoothed:
      # Ignore all pixels where a box was not supposed to be predicted but a
      # high confidence box was predicted.
      obj_mask = true_conf + (1 - true_conf) * ignore_mask
    else:
      # Replace pixels in the true confidence map with the max iou predicted
      # with in that cell.
      obj_mask = tf.ones_like(true_conf)
      iou_ = (1 - self._objectness_smooth) + self._objectness_smooth * iou_max
      iou_ = tf.where(iou_max > 0, iou_, tf.zeros_like(iou_))
      true_conf = tf.where(iou_mask, iou_, true_conf)

    # Stop gradient so while loop is not tracked.
    obj_mask = tf.stop_gradient(obj_mask)
    true_conf = tf.stop_gradient(true_conf)
    return true_conf, obj_mask

  def __call__(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Call function to compute the loss and return the total loss as
    well as the loss for each detection mask on a given FPN level.

    Args:
      true_counts: `Tensor` of shape [batchsize, height, width, num_anchors]
        representing how many boxes are in a given pixel [j, i] in the
        output map.
      inds: `Tensor` of shape [batchsize, None, 3] indicating the location
        [j, i] that a given box is associated with in the FPN prediction
        map.
      y_true: `Tensor` of shape [batchsize, None, 8] indicating the actual
        box associated with each index in the inds tensor list.
      boxes: `Tensor` of shape [batchsize, None, 4] indicating the original
        ground truth boxes for each image as they came from the decoder used
        for bounding box search.
      classes: `Tensor` of shape [batchsize, None, 1] indicating the
        original ground truth classes for each image as they came from the
        decoder used for bounding box search.
      y_pred: `Tensor` of shape [batchsize, height, width, output_depth]
        holding the models output at a specific FPN level.

    Return:
      loss: `float` for the actual loss.
      box_loss: `float` loss on the boxes used for metrics.
      conf_loss: `float` loss on the confidence used for metrics.
      class_loss: `float` loss on the classes used for metrics.
      mean_loss: `float` mean of the loss used for metrics.
      avg_iou: `float` metric for the average iou between predictions
        and ground truth.
      avg_obj: `float` metric for the average confidence of the model
        for predictions.
    """
    # Delegate the per-path loss math to the concrete subclass.
    (loss, box_loss, conf_loss, class_loss, mean_loss, iou, pred_conf,
     ind_mask, grid_mask) = self.call(true_counts, inds, y_true, boxes,
                                      classes, y_pred)

    # Temporary metrics: rescale the reported box loss; gradients are
    # detached so this only affects logging, not optimization.
    box_loss = tf.stop_gradient(0.05 * box_loss / self._iou_normalizer)

    # Metric compute using done here to save time and resources.
    sigmoid_conf = tf.stop_gradient(tf.sigmoid(pred_conf))
    iou = tf.stop_gradient(iou)
    avg_iou = loss_utils.avgiou(
        loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), iou))
    avg_obj = loss_utils.avgiou(
        tf.squeeze(sigmoid_conf, axis=-1) * grid_mask)
    return (loss, box_loss, conf_loss, class_loss, mean_loss,
            tf.stop_gradient(avg_iou), tf.stop_gradient(avg_obj))

  @abc.abstractmethod
  def _build_per_path_attributes(self):
    """Additional initialization required specifically for each unique YOLO
    loss version."""
    ...

  @abc.abstractmethod
  def call():
    """The actual logic to apply to the raw model for optimization.

    NOTE(review): declared without `self` here; every subclass overrides it
    with a normal instance method, so this stub is never invoked directly.
    """
    ...

  def post_path_aggregation(self, loss, ground_truths, predictions):
    """This method allows for post processing of a loss value after the loss
    has been aggregated across all the FPN levels. Default: identity."""
    return loss

  @abc.abstractmethod
  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """This controls how the loss should be aggregated across replicas."""
    ...
@tf.custom_gradient
def grad_sigmoid(values):
  """Identity op whose backward pass multiplies in a sigmoid derivative.

  Darknet only propagates sigmoids for the boxes under some conditions, so
  this op lets us selectively add the sigmoid term to the chain rule while
  leaving the forward values untouched.
  """

  def backward(upstream):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
    sig = tf.math.sigmoid(values)
    return upstream * sig * (1 - sig)

  return values, backward
class DarknetLoss(YoloLossBase):
  """This class implements the full logic for the standard Yolo models
  encompassing Yolov3, Yolov4, and Yolo-Tiny."""

  def _build_per_path_attributes(self):
    """Parameterization of pair wise search and grid generators for box
    decoding and dynamic ground truth association."""
    self._anchor_generator = loss_utils.GridGenerator(
        masks=self._masks,
        anchors=self._anchors,
        scale_anchors=self._path_stride)

    # The pair-wise search is only needed when the ignore threshold is
    # active; `any=True` matches any ground truth (class agnostic).
    if self._ignore_thresh > 0.0:
      self._search_pairs = loss_utils.PairWiseSearch(
          iou_type="iou", any=True, min_conf=0.25)
    return

  def call(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Per FPN path loss computation logic."""
    if self._box_type == "scaled":
      # Darknet Model Propagates a sigmoid once in back prop so we replicate
      # that behaviour
      y_pred = grad_sigmoid(y_pred)

    # Generate and store constants and format output.
    shape = tf.shape(true_counts)
    batch_size, width, height, num = shape[0], shape[1], shape[2], shape[3]
    fwidth = tf.cast(width, tf.float32)
    fheight = tf.cast(height, tf.float32)
    grid_points, anchor_grid = self._anchor_generator(
        width, height, batch_size, dtype=tf.float32)

    # Cast all input components to float32 and stop gradient to save memory.
    boxes = tf.stop_gradient(tf.cast(boxes, tf.float32))
    classes = tf.stop_gradient(tf.cast(classes, tf.float32))
    y_true = tf.stop_gradient(tf.cast(y_true, tf.float32))
    true_counts = tf.stop_gradient(tf.cast(true_counts, tf.float32))
    # A cell is "positive" if it holds at least one box (counts clipped to 1).
    true_conf = tf.stop_gradient(tf.clip_by_value(true_counts, 0.0, 1.0))
    grid_points = tf.stop_gradient(grid_points)
    anchor_grid = tf.stop_gradient(anchor_grid)

    # Split all the ground truths to use as separate items in loss
    # computation.
    (true_box, ind_mask, true_class, _, _) = tf.split(
        y_true, [4, 1, 1, 1, 1], axis=-1)
    true_conf = tf.squeeze(true_conf, axis=-1)
    true_class = tf.squeeze(true_class, axis=-1)
    grid_mask = true_conf

    # Splits all predictions.
    y_pred = tf.cast(
        tf.reshape(y_pred, [batch_size, width, height, num, -1]), tf.float32)
    pred_box, pred_conf, pred_class = tf.split(y_pred, [4, 1, -1], axis=-1)

    # Decode the boxes to be used for loss compute.
    _, _, pred_box = self._decode_boxes(
        fwidth, fheight, pred_box, anchor_grid, grid_points, darknet=True)

    # If the ignore threshold is enabled, search all boxes; ignore all
    # IOU values larger than the ignore threshold that are not in the
    # noted ground truth list.
    if self._ignore_thresh != 0.0:
      (true_conf, obj_mask) = self._tiled_global_box_search(
          pred_box,
          tf.stop_gradient(tf.sigmoid(pred_class)),
          boxes,
          classes,
          true_conf,
          smoothed=self._objectness_smooth > 0)

    # Build the one hot class list that are used for class loss.
    true_class = tf.one_hot(
        tf.cast(true_class, tf.int32),
        depth=tf.shape(pred_class)[-1],
        dtype=pred_class.dtype)
    true_classes = tf.stop_gradient(
        loss_utils.apply_mask(ind_mask, true_class))

    # Reorganize the one hot class list as a grid.
    true_class = loss_utils.build_grid(
        inds, true_classes, pred_class, ind_mask, update=False)
    true_class = tf.stop_gradient(true_class)

    # Use the class mask to find the number of objects located in
    # each predicted grid cell/pixel.
    counts = true_class
    counts = tf.reduce_sum(counts, axis=-1, keepdims=True)
    reps = tf.gather_nd(counts, inds, batch_dims=1)
    reps = tf.squeeze(reps, axis=-1)
    # Avoid division by zero below: cells with no objects get a count of 1.
    reps = tf.stop_gradient(tf.where(reps == 0.0, tf.ones_like(reps), reps))

    # Compute the loss for only the cells in which the boxes are located.
    pred_box = loss_utils.apply_mask(
        ind_mask, tf.gather_nd(pred_box, inds, batch_dims=1))
    iou, _, box_loss = self.box_loss(true_box, pred_box, darknet=True)
    box_loss = loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), box_loss)
    # Normalize by the number of objects sharing each cell.
    box_loss = math_ops.divide_no_nan(box_loss, reps)
    box_loss = tf.cast(tf.reduce_sum(box_loss, axis=1), dtype=y_pred.dtype)

    # Compute the sigmoid binary cross entropy for the class maps.
    class_loss = tf.reduce_mean(
        loss_utils.sigmoid_BCE(
            tf.expand_dims(true_class, axis=-1),
            tf.expand_dims(pred_class, axis=-1), self._label_smoothing),
        axis=-1)

    # Apply normalization to the class losses.
    if self._cls_normalizer < 1.0:
      # Build a mask based on the true class locations.
      cls_norm_mask = true_class
      # Apply the classes weight to class indexes where one_hot is one.
      class_loss *= ((1 - cls_norm_mask) +
                     cls_norm_mask * self._cls_normalizer)

    # Mask to the class loss and compute the sum over all the objects.
    class_loss = tf.reduce_sum(class_loss, axis=-1)
    class_loss = loss_utils.apply_mask(grid_mask, class_loss)
    class_loss = math_ops.rm_nan_inf(class_loss, val=0.0)
    class_loss = tf.cast(
        tf.reduce_sum(class_loss, axis=(1, 2, 3)), dtype=y_pred.dtype)

    # Compute the sigmoid binary cross entropy for the confidence maps.
    bce = tf.reduce_mean(
        loss_utils.sigmoid_BCE(
            tf.expand_dims(true_conf, axis=-1), pred_conf, 0.0),
        axis=-1)

    # Mask the confidence loss and take the sum across all the grid cells.
    if self._ignore_thresh != 0.0:
      bce = loss_utils.apply_mask(obj_mask, bce)
    conf_loss = tf.cast(
        tf.reduce_sum(bce, axis=(1, 2, 3)), dtype=y_pred.dtype)

    # Apply the weights to each loss.
    box_loss *= self._iou_normalizer
    conf_loss *= self._obj_normalizer

    # Add all the losses together then take the mean over the batches.
    loss = box_loss + class_loss + conf_loss
    loss = tf.reduce_mean(loss)

    # Reduce the mean of the losses to use as a metric.
    box_loss = tf.reduce_mean(box_loss)
    conf_loss = tf.reduce_mean(conf_loss)
    class_loss = tf.reduce_mean(class_loss)

    # `loss` doubles as `mean_loss` for the darknet loss path.
    return (loss, box_loss, conf_loss, class_loss, loss, iou, pred_conf,
            ind_mask, grid_mask)

  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """this method is not specific to each loss path, but each loss type"""
    # The darknet loss is averaged across replicas.
    return loss / num_replicas_in_sync
class ScaledLoss(YoloLossBase):
  """This class implements the full logic for the scaled Yolo models
  encompassing Yolov4-csp, Yolov4-Large, and Yolov5."""

  def _build_per_path_attributes(self):
    """Parameterization of pair wise search and grid generators for box
    decoding and dynamic ground truth association."""
    self._anchor_generator = loss_utils.GridGenerator(
        masks=self._masks,
        anchors=self._anchors,
        scale_anchors=self._path_stride)

    # Unlike DarknetLoss, the search uses the configured IOU type and
    # requires class agreement (`any=False`).
    if self._ignore_thresh > 0.0:
      self._search_pairs = loss_utils.PairWiseSearch(
          iou_type=self._loss_type, any=False, min_conf=0.25)
    return

  def call(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Per FPN path loss computation logic."""
    # Generate shape constants.
    shape = tf.shape(true_counts)
    batch_size, width, height, num = shape[0], shape[1], shape[2], shape[3]
    fwidth = tf.cast(width, tf.float32)
    fheight = tf.cast(height, tf.float32)

    # Cast all input components to float32 and stop gradient to save memory.
    y_true = tf.cast(y_true, tf.float32)
    true_counts = tf.cast(true_counts, tf.float32)
    true_conf = tf.clip_by_value(true_counts, 0.0, 1.0)
    grid_points, anchor_grid = self._anchor_generator(
        width, height, batch_size, dtype=tf.float32)

    # Split the y_true list.
    (true_box, ind_mask, true_class, _, _) = tf.split(
        y_true, [4, 1, 1, 1, 1], axis=-1)
    grid_mask = true_conf = tf.squeeze(true_conf, axis=-1)
    true_class = tf.squeeze(true_class, axis=-1)
    # Total number of real (unpadded) objects, used as the loss normalizer.
    num_objs = tf.cast(tf.reduce_sum(ind_mask), dtype=y_pred.dtype)

    # Split up the predictions.
    y_pred = tf.cast(
        tf.reshape(y_pred, [batch_size, width, height, num, -1]), tf.float32)
    pred_box, pred_conf, pred_class = tf.split(y_pred, [4, 1, -1], axis=-1)

    # Decode the boxes for loss compute.
    scale, pred_box, _ = self._decode_boxes(
        fwidth, fheight, pred_box, anchor_grid, grid_points, darknet=False)

    # If the ignore threshold is enabled, search all boxes; ignore all
    # IOU values larger than the ignore threshold that are not in the
    # noted ground truth list.
    if self._ignore_thresh != 0.0:
      (_, obj_mask) = self._tiled_global_box_search(
          pred_box,
          tf.stop_gradient(tf.sigmoid(pred_class)),
          boxes,
          classes,
          true_conf,
          smoothed=False,
          scale=scale)

    # Scale and shift and select the ground truth boxes
    # and predictions to the prediction domain.
    offset = tf.cast(
        tf.gather_nd(grid_points, inds, batch_dims=1), true_box.dtype)
    # Only the x,y components are offset; w,h get zeros.
    offset = tf.concat([offset, tf.zeros_like(offset)], axis=-1)
    true_box = loss_utils.apply_mask(ind_mask, (scale * true_box) - offset)
    pred_box = loss_utils.apply_mask(
        ind_mask, tf.gather_nd(pred_box, inds, batch_dims=1))

    # Select the correct/used prediction classes.
    true_class = tf.one_hot(
        tf.cast(true_class, tf.int32),
        depth=tf.shape(pred_class)[-1],
        dtype=pred_class.dtype)
    true_class = loss_utils.apply_mask(ind_mask, true_class)
    pred_class = loss_utils.apply_mask(
        ind_mask, tf.gather_nd(pred_class, inds, batch_dims=1))

    # Compute the box loss.
    _, iou, box_loss = self.box_loss(true_box, pred_box, darknet=False)
    box_loss = loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), box_loss)
    box_loss = math_ops.divide_no_nan(tf.reduce_sum(box_loss), num_objs)

    # Use the box IOU to build the map for confidence loss computation.
    iou = tf.maximum(tf.stop_gradient(iou), 0.0)
    smoothed_iou = (((1 - self._objectness_smooth) *
                     tf.cast(ind_mask, iou.dtype)) +
                    self._objectness_smooth * tf.expand_dims(iou, axis=-1))
    smoothed_iou = loss_utils.apply_mask(ind_mask, smoothed_iou)
    true_conf = loss_utils.build_grid(
        inds, smoothed_iou, pred_conf, ind_mask,
        update=self._update_on_repeat)
    true_conf = tf.squeeze(true_conf, axis=-1)

    # Compute the cross entropy loss for the confidence map.
    bce = tf.keras.losses.binary_crossentropy(
        tf.expand_dims(true_conf, axis=-1), pred_conf, from_logits=True)
    if self._ignore_thresh != 0.0:
      bce = loss_utils.apply_mask(obj_mask, bce)
    conf_loss = tf.reduce_mean(bce)

    # Compute the cross entropy loss for the class maps.
    class_loss = tf.keras.losses.binary_crossentropy(
        true_class,
        pred_class,
        label_smoothing=self._label_smoothing,
        from_logits=True)
    class_loss = loss_utils.apply_mask(
        tf.squeeze(ind_mask, axis=-1), class_loss)
    class_loss = math_ops.divide_no_nan(
        tf.reduce_sum(class_loss), num_objs)

    # Apply the weights to each loss.
    box_loss *= self._iou_normalizer
    class_loss *= self._cls_normalizer
    conf_loss *= self._obj_normalizer

    # Add all the losses together then take the sum over the batches.
    mean_loss = box_loss + class_loss + conf_loss
    loss = mean_loss * tf.cast(batch_size, mean_loss.dtype)

    return (loss, box_loss, conf_loss, class_loss, mean_loss, iou, pred_conf,
            ind_mask, grid_mask)

  def post_path_aggregation(self, loss, ground_truths, predictions):
    """Rescales the loss so that its magnitude is independent of the number
    of FPN levels (normalized against the typical 3-level setup)."""
    scale = tf.stop_gradient(3 / len(list(predictions.keys())))
    return loss * scale

  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """this method is not specific to each loss path, but each loss type"""
    # The scaled loss is summed across replicas (no division).
    return loss
# Factory table mapping a loss-style name to the class implementing it;
# consumed by YoloLoss below.
LOSSES = {"darknet": DarknetLoss, "scaled": ScaledLoss}
class YoloLoss(object):
  """This class implements the aggregated loss across paths for the YOLO
  model. The class implements the YOLO loss as a factory in order to allow
  selection and implementation of new versions of the YOLO loss as the model
  is updated in the future.
  """

  def __init__(self,
               keys,
               classes,
               anchors,
               masks=None,
               path_strides=None,
               truth_thresholds=None,
               ignore_thresholds=None,
               loss_types=None,
               iou_normalizers=None,
               cls_normalizers=None,
               obj_normalizers=None,
               objectness_smooths=None,
               box_types=None,
               scale_xys=None,
               max_deltas=None,
               label_smoothing=0.0,
               use_scaled_loss=False,
               update_on_repeat=True):
    """Loss Function Initialization.

    Args:
      keys: `List[str]` indicating the name of the FPN paths that need to be
        optimized.
      classes: `int` for the number of classes.
      anchors: `List[List[int]]` for the anchor boxes that are used in the
        model at all levels. For anchor free prediction set the anchor list
        to be the same as the image resolution.
      masks: `Dict[List[int]]` for the anchor indexes used at each output
        level.
      path_strides: `Dict[int]` for how much to scale this level to get the
        original input shape for each FPN path.
      truth_thresholds: `Dict[float]` for the IOU value over which the loss
        is propagated despite a detection being made for each FPN path.
      ignore_thresholds: `Dict[float]` for the IOU value over which the loss
        is not propagated, and a detection is assumed to have been made for
        each FPN path.
      loss_types: `Dict[str]` for the type of iou loss to use with in {ciou,
        diou, giou, iou} for each FPN path.
      iou_normalizers: `Dict[float]` for how much to scale the loss on the
        IOU or the boxes for each FPN path.
      cls_normalizers: `Dict[float]` for how much to scale the loss on the
        classes for each FPN path.
      obj_normalizers: `Dict[float]` for how much to scale loss on the
        detection map for each FPN path.
      objectness_smooths: `Dict[float]` for how much to smooth the loss on
        the detection map for each FPN path.
      box_types: `Dict[str]` for which scaling type to use for each FPN
        path.
      scale_xys: `Dict[float]` values indicating how far each pixel can see
        outside of its containment of 1.0. A value of 1.2 indicates there is
        a 20% extended radius around each pixel that this specific pixel can
        predict values for a center at. The center can range from 0 - value/2
        to 1 + value/2; this value is set in the yolo filter and reused here.
        There should be one value for scale_xy for each level from min_level
        to max_level. One for each FPN path.
      max_deltas: `Dict[float]` for gradient clipping to apply to the box
        loss for each FPN path.
      label_smoothing: `float` for how much to smooth the loss on the
        classes.
      use_scaled_loss: `bool` for whether to use the scaled loss
        or the traditional loss.
      update_on_repeat: `bool` for whether to replace with the newest or
        the best value when an index is consumed by multiple objects.
    """
    # Select the concrete per-path loss implementation from the factory.
    if use_scaled_loss:
      loss_type = "scaled"
    else:
      loss_type = "darknet"

    # One loss instance per FPN path, each with its own per-path settings.
    self._loss_dict = {}
    for key in keys:
      self._loss_dict[key] = LOSSES[loss_type](
          classes=classes,
          anchors=anchors,
          mask=masks[key],
          truth_thresh=truth_thresholds[key],
          ignore_thresh=ignore_thresholds[key],
          loss_type=loss_types[key],
          iou_normalizer=iou_normalizers[key],
          cls_normalizer=cls_normalizers[key],
          obj_normalizer=obj_normalizers[key],
          box_type=box_types[key],
          objectness_smooth=objectness_smooths[key],
          max_delta=max_deltas[key],
          path_stride=path_strides[key],
          scale_x_y=scale_xys[key],
          update_on_repeat=update_on_repeat,
          label_smoothing=label_smoothing)

  def __call__(self, ground_truth, predictions, use_reduced_logs=True):
    """Computes the total loss across all FPN paths.

    Args:
      ground_truth: dict with keys 'true_conf', 'inds', 'upds' (per-path
        dicts) and 'bbox', 'classes' (shared tensors), as read below.
      predictions: dict mapping FPN path key to the model output tensor.
      use_reduced_logs: `bool`; when False, per-path box/conf/class losses
        are additionally recorded in the metric dict.

    Returns:
      loss_val: the differentiable aggregated loss.
      metric_loss: gradient-detached sum of per-path mean losses.
      metric_dict: nested dict of gradient-detached logging metrics.
    """
    metric_dict = defaultdict(dict)
    metric_dict['net']['box'] = 0
    metric_dict['net']['class'] = 0
    metric_dict['net']['conf'] = 0
    loss_val, metric_loss = 0, 0

    num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
    for key in predictions.keys():
      (_loss, _loss_box, _loss_conf, _loss_class, _mean_loss, _avg_iou,
       _avg_obj) = self._loss_dict[key](ground_truth['true_conf'][key],
                                        ground_truth['inds'][key],
                                        ground_truth['upds'][key],
                                        ground_truth['bbox'],
                                        ground_truth['classes'],
                                        predictions[key])

      # after computing the loss, scale loss as needed for aggregation
      # across FPN levels
      _loss = self._loss_dict[key].post_path_aggregation(
          _loss, ground_truth, predictions)

      # after completing the scaling of the loss on each replica, handle
      # scaling the loss for merging the loss across replicas
      _loss = self._loss_dict[key].cross_replica_aggregation(
          _loss, num_replicas_in_sync)
      loss_val += _loss

      # detach all the below gradients: none of them should make a
      # contribution to the gradient from this point forwards
      metric_loss += tf.stop_gradient(_mean_loss)
      metric_dict[key]['loss'] = tf.stop_gradient(_mean_loss)
      metric_dict[key]['avg_iou'] = tf.stop_gradient(_avg_iou)
      metric_dict[key]["avg_obj"] = tf.stop_gradient(_avg_obj)

      if not use_reduced_logs:
        metric_dict[key]['conf_loss'] = tf.stop_gradient(_loss_conf)
        metric_dict[key]['box_loss'] = tf.stop_gradient(_loss_box)
        metric_dict[key]['class_loss'] = tf.stop_gradient(_loss_class)

      metric_dict['net']['box'] += tf.stop_gradient(_loss_box)
      metric_dict['net']['class'] += tf.stop_gradient(_loss_class)
      metric_dict['net']['conf'] += tf.stop_gradient(_loss_conf)

    return loss_val, metric_loss, metric_dict
official/vision/beta/projects/yolo/losses/yolo_loss_test.py
0 → 100755
View file @
8e9db0f1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for yolo heads."""
# Import libraries
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.projects.yolo.losses
import
yolo_loss
class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the aggregated YOLO loss function."""

  def test_loss_init(self):
    """Tests creation and a forward call of the YOLO loss function."""

    def inpdict(input_shape, dtype=tf.float32):
      # Build a dict of all-ones tensors keyed the same as `input_shape`.
      inputs = {}
      for key in input_shape:
        inputs[key] = tf.ones(input_shape[key], dtype=dtype)
      return inputs

    tf.keras.backend.set_image_data_format('channels_last')
    input_shape = {
        '3': [1, 52, 52, 255],
        '4': [1, 26, 26, 255],
        '5': [1, 13, 13, 255]
    }
    classes = 80
    masks = {'3': [0, 1, 2], '4': [3, 4, 5], '5': [6, 7, 8]}
    anchors = [[12.0, 19.0], [31.0, 46.0], [96.0, 54.0], [46.0, 114.0],
               [133.0, 127.0], [79.0, 225.0], [301.0, 150.0], [172.0, 286.0],
               [348.0, 340.0]]
    keys = ['3', '4', '5']
    path_strides = {key: 2**int(key) for key in keys}

    loss = yolo_loss.YoloLoss(
        keys,
        classes,
        anchors,
        masks=masks,
        path_strides=path_strides,
        truth_thresholds={key: 1.0 for key in keys},
        ignore_thresholds={key: 0.7 for key in keys},
        loss_types={key: "ciou" for key in keys},
        iou_normalizers={key: 0.05 for key in keys},
        cls_normalizers={key: 0.5 for key in keys},
        obj_normalizers={key: 1.0 for key in keys},
        objectness_smooths={key: 1.0 for key in keys},
        box_types={key: "scaled" for key in keys},
        scale_xys={key: 2.0 for key in keys},
        max_deltas={key: 30.0 for key in keys},
        label_smoothing=0.0,
        use_scaled_loss=True,
        update_on_repeat=True)

    # Ground truths: per-path confidence maps, indices, and updates plus
    # shared box/class tensors, all filled with ones.
    count = inpdict({
        '3': [1, 52, 52, 3, 1],
        '4': [1, 26, 26, 3, 1],
        '5': [1, 13, 13, 3, 1]
    })
    ind = inpdict(
        {
            '3': [1, 300, 3],
            '4': [1, 300, 3],
            '5': [1, 300, 3]
        }, tf.int32)
    truths = inpdict({
        '3': [1, 300, 8],
        '4': [1, 300, 8],
        '5': [1, 300, 8]
    })
    boxes = tf.ones([1, 300, 4], dtype=tf.float32)
    classes = tf.ones([1, 300], dtype=tf.float32)
    gt = {
        "true_conf": count,
        "inds": ind,
        "upds": truths,
        "bbox": boxes,
        "classes": classes
    }

    loss_val, metric_loss, metric_dict = loss(gt, inpdict(input_shape))

    # The original version of this test only checked that the call did not
    # crash; also assert the outputs are finite and the per-level metrics
    # were populated.
    self.assertTrue(bool(tf.reduce_all(tf.math.is_finite(loss_val))))
    self.assertTrue(bool(tf.reduce_all(tf.math.is_finite(metric_loss))))
    for key in keys:
      self.assertIn(key, metric_dict)
      self.assertIn('loss', metric_dict[key])
      self.assertIn('avg_iou', metric_dict[key])
      self.assertIn('avg_obj', metric_dict[key])
# Allow running this test file directly as a script.
if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/yolo/ops/loss_utils.py
0 → 100755
View file @
8e9db0f1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Yolo loss utility functions."""
import
tensorflow
as
tf
import
numpy
as
np
from
official.vision.beta.projects.yolo.ops
import
(
box_ops
,
math_ops
)
@tf.custom_gradient
def sigmoid_BCE(y, x_prime, label_smoothing):
  """Applies the Sigmoid Cross Entropy Loss using the same derivative as that
  found in the Darknet C library. The derivative of this method is not the same
  as the standard binary cross entropy with logits function.

  The BCE with logits function equation is as follows:
    x = 1 / (1 + exp(-x_prime))
    bce = -ylog(x) - (1 - y)log(1 - x)

  The standard BCE with logits function derivative is as follows:
    dloss = -y/x + (1-y)/(1-x)
    dsigmoid = x * (1 - x)
    dx = dloss * dsigmoid

  This derivative can be reduced simply to:
    dx = (-y + x)

  This simplification is used by the darknet library in order to improve
  training stability. The gradient is almost the same
  as tf.keras.losses.binary_crossentropy but varies slightly and
  yields different performance.

  Args:
    y: `Tensor` holding ground truth data.
    x_prime: `Tensor` holding the predictions prior to application of the
      sigmoid operation.
    label_smoothing: float value between 0.0 and 1.0 indicating the amount of
      smoothing to apply to the data.

  Returns:
    bce: Tensor of the applied loss values.
    delta: callable function indicating the custom gradient for this operation.
  """
  # Epsilon keeps the log terms finite when the sigmoid saturates at 0 or 1.
  eps = 1e-9
  x = tf.math.sigmoid(x_prime)
  # Smooth the labels toward 0.5 and detach them from the gradient tape;
  # only x_prime receives a gradient (see `delta` below).
  y = tf.stop_gradient(y * (1 - label_smoothing) + 0.5 * label_smoothing)
  bce = -y * tf.math.log(x + eps) - (1 - y) * tf.math.log(1 - x + eps)

  def delta(dpass):
    # Darknet-style simplified derivative: d(bce)/d(x_prime) = (x - y).
    x = tf.math.sigmoid(x_prime)
    dx = (-y + x) * dpass
    # No gradient flows to the labels; the smoothing scalar gets 0.0.
    dy = tf.zeros_like(y)
    return dy, dx, 0.0

  return bce, delta
def apply_mask(mask, x, value=0):
  """Hard-masks a tensor while keeping the gradient TPU-safe.

  The YOLO loss makes extensive use of dynamically shaped tensors. To support
  this on TPU while preserving back propagation, masked locations are set to
  a constant via `tf.where` (which has a well-defined gradient) instead of
  boolean indexing.

  Args:
    mask: A `Tensor` with the same shape as `x` selecting the values to keep.
    x: A `Tensor` with the same shape as `mask` to be masked.
    value: Scalar written into the masked-out locations (defaults to 0).

  Returns:
    A `Tensor` with the same shape as `x` where unselected entries are
    replaced by `value`.
  """
  keep = tf.cast(mask, tf.bool)
  fill = tf.zeros_like(x) + value
  return tf.where(keep, x, fill)
def build_grid(indexes, truths, preds, ind_mask, update=False, grid=None):
  """This function is used to broadcast all the indexes to the correct
  ground truth mask, used for iou detection map in the scaled loss and
  the classification mask in the darknet loss.

  Args:
    indexes: A `Tensor` for the indexes.
    truths: A `Tensor` for the ground truth.
    preds: A `Tensor` for the predictions.
    ind_mask: A `Tensor` for the index masks.
    update: A `bool`; if True scatter with overwrite semantics
      (`tensor_scatter_nd_update`), otherwise take the element-wise max
      (`tensor_scatter_nd_max`).
    grid: An optional `Tensor` to scatter into; a zero grid shaped like
      `preds` is built when None.

  Returns:
    grid: A `Tensor` representing the augmented grid.
  """
  # this function is used to broadcast all the indexes to the correct
  # into the correct ground truth mask, used for iou detection map
  # in the scaled loss and the classification mask in the darknet loss
  num_flatten = tf.shape(preds)[-1]

  # is there a way to verify that we are not on the CPU?
  ind_mask = tf.cast(ind_mask, indexes.dtype)

  # find all the batch indexes using the cumulative sum of a ones tensor;
  # cumsum(ones) - 1 yields the zero-indexed batch ids
  bhep = tf.reduce_max(tf.ones_like(indexes), axis=-1, keepdims=True)
  bhep = tf.math.cumsum(bhep, axis=0) - 1

  # concatenate the batch ids to the indexes
  indexes = tf.concat([bhep, indexes], axis=-1)
  indexes = apply_mask(tf.cast(ind_mask, indexes.dtype), indexes)
  # masked entries become -1 so the scatter below ignores them
  indexes = (indexes + (ind_mask - 1))

  # reshape the indexes into the correct shape for the loss,
  # just flatten all indexes but the last
  indexes = tf.reshape(indexes, [-1, 4])

  # also flatten the ground truth value on all axis but the last
  truths = tf.reshape(truths, [-1, num_flatten])

  # build a zero grid in the same shape as the predictions
  if grid is None:
    grid = tf.zeros_like(preds)
  # remove invalid values from the truths that may have
  # come up from computation, invalid = nan and inf
  truths = math_ops.rm_nan_inf(truths)

  # scatter update the zero grid
  if update:
    grid = tf.tensor_scatter_nd_update(grid, indexes, truths)
  else:
    grid = tf.tensor_scatter_nd_max(grid, indexes, truths)

  # NOTE(review): the original comment claimed a stop_gradient here, but none
  # is applied — the grid is returned with gradients intact; confirm whether
  # callers wrap this in tf.stop_gradient themselves.
  return grid
class GridGenerator(object):
  """Grid generator that generates anchor grids that will be used
  in to decode the predicted boxes."""

  def __init__(self, anchors, masks=None, scale_anchors=None):
    """Initialize Grid Generator.

    Args:
      anchors: A `List[List[int]]` for the anchor boxes that are used in the
        model at all levels.
      masks: A `List[int]` selecting which anchors belong to this output
        level; when None, all anchors are used.
      scale_anchors: An `int` for how much to scale this level to get the
        original input shape.
    """
    self.dtype = tf.keras.backend.floatx()
    # _num is a Python int when masks are given, otherwise a scalar Tensor.
    if masks is not None:
      self._num = len(masks)
    else:
      self._num = tf.shape(anchors)[0]

    # keep only the anchors assigned to this level
    if masks is not None:
      anchors = [anchors[mask] for mask in masks]

    self._scale_anchors = scale_anchors
    self._anchors = tf.convert_to_tensor(anchors)
    return

  def _build_grid_points(self, lwidth, lheight, anchors, dtype):
    """Generate a grid that is used to determine the relative centers
    of the bounding boxes."""
    # NOTE(review): the names look swapped — `x_left` is built from the
    # height range and `y_left` from the width range. For the square feature
    # maps used here the result is the same; confirm before using with
    # non-square grids.
    with tf.name_scope('center_grid'):
      y = tf.range(0, lheight)
      x = tf.range(0, lwidth)
      num = tf.shape(anchors)[0]
      x_left = tf.tile(
          tf.transpose(tf.expand_dims(y, axis=-1), perm=[1, 0]), [lwidth, 1])
      y_left = tf.tile(tf.expand_dims(x, axis=-1), [1, lheight])
      x_y = tf.stack([x_left, y_left], axis=-1)
      x_y = tf.cast(x_y, dtype=dtype)
      # replicate per anchor and add a leading batch axis of 1
      x_y = tf.expand_dims(
          tf.tile(tf.expand_dims(x_y, axis=-2), [1, 1, num, 1]), axis=0)
    return x_y

  def _build_anchor_grid(self, anchors, dtype):
    """Get the transformed anchor boxes for each dimension."""
    with tf.name_scope('anchor_grid'):
      num = tf.shape(anchors)[0]
      anchors = tf.cast(anchors, dtype=dtype)
      # shape [1, 1, 1, num_anchors, 2] so it broadcasts over the grid
      anchors = tf.reshape(anchors, [1, 1, 1, num, 2])
    return anchors

  def _extend_batch(self, grid, batch_size):
    # replicate the leading axis of 1 to the requested batch size
    return tf.tile(grid, [batch_size, 1, 1, 1, 1])

  def __call__(self, width, height, batch_size, dtype=None):
    if dtype is None:
      self.dtype = tf.keras.backend.floatx()
    else:
      self.dtype = dtype
    grid_points = self._build_grid_points(width, height, self._anchors,
                                          self.dtype)
    # anchors are stored in input-image pixels; divide by the level stride
    # to express them in feature-map units
    anchor_grid = self._build_anchor_grid(
        tf.cast(self._anchors, self.dtype) /
        tf.cast(self._scale_anchors, self.dtype), self.dtype)

    grid_points = self._extend_batch(grid_points, batch_size)
    anchor_grid = self._extend_batch(anchor_grid, batch_size)
    return grid_points, anchor_grid
# Number of ground-truth boxes processed per while-loop iteration in
# PairWiseSearch; tiling bounds the peak memory of the pairwise IoU search.
TILE_SIZE = 50
class PairWiseSearch(object):
  """This method applies a pairwise search between the ground truth
  and the labels. The goal is to indicate the locations where the
  predictions overlap with ground truth for dynamic ground
  truth constructions."""

  def __init__(self,
               iou_type='iou',
               any=True,
               min_conf=0.0,
               track_boxes=False,
               track_classes=False):
    """Initialization of Pair Wise Search.

    Args:
      iou_type: An `str` for the iou type to use ('iou', 'giou' or 'ciou').
      any: A `bool` for any match (no class match). NOTE(review): this
        parameter shadows the builtin `any`; kept as-is for caller
        compatibility.
      min_conf: An `int` for minimum confidence threshold.
      track_boxes: A `bool` dynamic box assignment.
      track_classes: A `bool` dynamic class assignment.
    """
    self.iou_type = iou_type
    self._any = any
    self._min_conf = min_conf
    self._track_boxes = track_boxes
    self._track_classes = track_classes
    return

  def box_iou(self, true_box, pred_box):
    # based on the type of loss, compute the iou loss for a box
    # compute_<name> indicates the type of iou to use; giou/ciou return
    # (loss, iou) pairs, only the iou component is kept here
    if self.iou_type == 'giou':
      _, iou = box_ops.compute_giou(true_box, pred_box)
    elif self.iou_type == 'ciou':
      _, iou = box_ops.compute_ciou(true_box, pred_box)
    else:
      iou = box_ops.compute_iou(true_box, pred_box)
    return iou

  def _search_body(self, pred_box, pred_class, boxes, classes, running_boxes,
                   running_classes, max_iou, idx):
    """One while-loop step: compare predictions against one tile of
    ground-truth boxes and fold the result into the running maxima."""
    # capture the batch size to be used, and gather a slice of
    # boxes from the ground truth. currently TILE_SIZE = 50, to
    # save memory
    batch_size = tf.shape(boxes)[0]
    box_slice = tf.slice(boxes, [0, idx * TILE_SIZE, 0],
                         [batch_size, TILE_SIZE, 4])

    # match the dimensions of the slice to the model predictions
    # shape: [batch_size, 1, 1, num, TILE_SIZE, 4]
    box_slice = tf.expand_dims(box_slice, axis=1)
    box_slice = tf.expand_dims(box_slice, axis=1)
    box_slice = tf.expand_dims(box_slice, axis=1)
    box_grid = tf.expand_dims(pred_box, axis=-2)

    # capture the classes for the same tile
    class_slice = tf.slice(classes, [0, idx * TILE_SIZE],
                           [batch_size, TILE_SIZE])
    class_slice = tf.expand_dims(class_slice, axis=1)
    class_slice = tf.expand_dims(class_slice, axis=1)
    class_slice = tf.expand_dims(class_slice, axis=1)

    iou = self.box_iou(box_slice, box_grid)

    if self._min_conf > 0.0:
      if not self._any:
        # class-aware matching: zero out ious whose predicted class does not
        # match the ground-truth class
        class_grid = tf.expand_dims(pred_class, axis=-2)
        class_mask = tf.one_hot(
            tf.cast(class_slice, tf.int32),
            depth=tf.shape(pred_class)[-1],
            dtype=pred_class.dtype)
        class_mask = tf.reduce_any(tf.equal(class_mask, class_grid), axis=-1)
      else:
        # class-agnostic: only require some class confidence above threshold
        class_mask = tf.reduce_max(pred_class, axis=-1, keepdims=True)
      class_mask = tf.cast(class_mask, iou.dtype)
      iou *= class_mask

    # fold this tile's ious into the running per-location maximum; `ind`
    # is the argmax over [previous-best, tile] used for dynamic assignment
    max_iou_ = tf.concat([max_iou, iou], axis=-1)
    max_iou = tf.reduce_max(max_iou_, axis=-1, keepdims=True)
    ind = tf.expand_dims(tf.argmax(max_iou_, axis=-1), axis=-1)

    if self._track_boxes:
      # gather the box (previous best or one from this tile) at the argmax
      running_boxes = tf.expand_dims(running_boxes, axis=-2)
      box_slice = tf.zeros_like(running_boxes) + box_slice
      box_slice = tf.concat([running_boxes, box_slice], axis=-2)
      running_boxes = tf.gather_nd(box_slice, ind, batch_dims=4)

    if self._track_classes:
      # same gather for the matched class
      running_classes = tf.expand_dims(running_classes, axis=-1)
      class_slice = tf.zeros_like(running_classes) + class_slice
      class_slice = tf.concat([running_classes, class_slice], axis=-1)
      running_classes = tf.gather_nd(class_slice, ind, batch_dims=4)

    return (pred_box, pred_class, boxes, classes, running_boxes,
            running_classes, max_iou, idx + 1)

  def __call__(self,
               pred_boxes,
               pred_classes,
               boxes,
               classes,
               scale=None,
               yxyx=True,
               clip_thresh=0.0):
    num_boxes = tf.shape(boxes)[-2]
    # NOTE(review): this looks like it skips the final tile
    # (e.g. 300 boxes -> idx 0..4 covers only 250); the loop also stops early
    # on an all-zero (padded) tile — confirm this off-by-one is intentional.
    num_tiles = (num_boxes // TILE_SIZE) - 1

    if yxyx:
      boxes = box_ops.yxyx_to_xcycwh(boxes)

    if scale is not None:
      boxes = boxes * tf.stop_gradient(scale)

    if self._min_conf > 0.0:
      # binarize class scores at the confidence threshold
      pred_classes = tf.cast(pred_classes > self._min_conf, pred_classes.dtype)

    def _loop_cond(pred_box, pred_class, boxes, classes, running_boxes,
                   running_classes, max_iou, idx):
      # stop when out of tiles or when the next slice holds only zero
      # (padding) boxes
      batch_size = tf.shape(boxes)[0]
      box_slice = tf.slice(boxes, [0, idx * TILE_SIZE, 0],
                           [batch_size, TILE_SIZE, 4])
      return tf.logical_and(idx < num_tiles,
                            tf.math.greater(tf.reduce_sum(box_slice), 0))

    # initialize loop state: zero boxes/classes and a zero iou accumulator
    running_boxes = tf.zeros_like(pred_boxes)
    running_classes = tf.zeros_like(tf.reduce_sum(running_boxes, axis=-1))
    max_iou = tf.zeros_like(tf.reduce_sum(running_boxes, axis=-1))
    max_iou = tf.expand_dims(max_iou, axis=-1)

    (pred_boxes, pred_classes, boxes, classes, running_boxes, running_classes,
     max_iou, idx) = tf.while_loop(_loop_cond, self._search_body, [
         pred_boxes, pred_classes, boxes, classes, running_boxes,
         running_classes, max_iou,
         tf.constant(0)
     ])

    # zero out matches whose best iou does not exceed the clip threshold
    mask = tf.cast(max_iou > clip_thresh, running_boxes.dtype)
    running_boxes *= mask
    running_classes *= tf.squeeze(mask, axis=-1)
    max_iou *= mask
    max_iou = tf.squeeze(max_iou, axis=-1)
    mask = tf.squeeze(mask, axis=-1)

    # detach everything: the search produces targets, not gradients
    return (tf.stop_gradient(running_boxes), tf.stop_gradient(running_classes),
            tf.stop_gradient(max_iou), tf.stop_gradient(mask))
def avgiou(iou):
  """Computes the average intersection over union without counting locations
  where the iou is zero.

  Args:
    iou: A `Tensor` representing the iou values.

  Returns:
    A gradient-stopped scalar `Tensor` holding the average intersection
    over union across the non-zero entries.
  """
  # Reduce over every axis except the batch axis (axis 0).
  reduce_axes = tf.range(1, tf.shape(tf.shape(iou))[0])
  iou_sum = tf.reduce_sum(iou, axis=reduce_axes)
  nonzero = tf.math.count_nonzero(iou, axis=reduce_axes)
  counts = tf.cast(nonzero, iou.dtype)
  # Per-sample mean over non-zero entries, then mean over the batch;
  # divide_no_nan guards samples with no non-zero ious.
  avg_iou = tf.reduce_mean(math_ops.divide_no_nan(iou_sum, counts))
  return tf.stop_gradient(avg_iou)
def _scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points,
                 scale_xy):
  """Decodes 'original'-type YOLO box encodings (sigmoid centers,
  exponential width/height) into feature-map and normalized coordinates.

  Returns:
    scaler: [height, width, height, width] tensor used to normalize boxes.
    scaled_box: boxes as [sigmoid-xy offset, wh] in feature-map units.
    pred_box: boxes divided by `scaler`, i.e. in the [0, 1] range.
  """
  # split the boxes into centers and width/height
  pred_xy = encoded_boxes[..., 0:2]
  pred_wh = encoded_boxes[..., 2:4]

  # build a scaling tensor to get the offset of the box relative to the image
  scaler = tf.convert_to_tensor([height, width, height, width])
  scale_xy = tf.cast(scale_xy, encoded_boxes.dtype)

  # apply the sigmoid to the predicted centers
  pred_xy = tf.math.sigmoid(pred_xy)

  # scale the centers and find the offset of each box relative to
  # their center pixel
  pred_xy = pred_xy * scale_xy - 0.5 * (scale_xy - 1)

  # scale the offsets and add them to the grid points or a tensor that is
  # the relative location of each pixel
  box_xy = grid_points + pred_xy

  # scale the width and height of the predictions and correlate them
  # to anchor boxes
  box_wh = tf.math.exp(pred_wh) * anchor_grid

  # build the final predicted box
  scaled_box = tf.concat([box_xy, box_wh], axis=-1)
  pred_box = scaled_box / scaler

  # shift scaled boxes: keep the raw xy offsets alongside the decoded wh
  scaled_box = tf.concat([pred_xy, box_wh], axis=-1)
  return (scaler, scaled_box, pred_box)
@tf.custom_gradient
def _darknet_boxes(encoded_boxes, width, height, anchor_grid, grid_points,
                   max_delta, scale_xy):
  """Darknet variant of `_scale_boxes`: identical forward pass, but a custom
  backward pass that merges the gradients of `scaled_box` and `pred_box`,
  re-applies the exponential's derivative to wh, and clips to `max_delta`."""
  (scaler, scaled_box, pred_box) = _scale_boxes(encoded_boxes, width, height,
                                                anchor_grid, grid_points,
                                                scale_xy)

  def delta(dy_scaler, dy_scaled, dy):
    dy_xy, dy_wh = tf.split(dy, 2, axis=-1)
    dy_xy_, dy_wh_ = tf.split(dy_scaled, 2, axis=-1)

    # add all the gradients that may have been applied to the
    # boxes and those that have been applied to the width and height
    dy_wh += dy_wh_
    dy_xy += dy_xy_

    # propagate the exponential applied to the width and height in
    # order to ensure the gradient propagated is of the correct
    # magnitude
    pred_wh = encoded_boxes[..., 2:4]
    dy_wh *= tf.math.exp(pred_wh)

    dbox = tf.concat([dy_xy, dy_wh], axis=-1)

    # apply the gradient clipping to xy and wh (after scrubbing nan/inf)
    dbox = math_ops.rm_nan_inf(dbox)
    delta = tf.cast(max_delta, dbox.dtype)
    dbox = tf.clip_by_value(dbox, -delta, delta)
    # only encoded_boxes receives a gradient; all other inputs get 0.0
    return dbox, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

  return (scaler, scaled_box, pred_box), delta
def _new_coord_scale_boxes(encoded_boxes, width, height, anchor_grid,
                           grid_points, scale_xy):
  """Decodes 'scaled'-type YOLO box encodings: sigmoid on both centers and
  width/height, with wh decoded as (2 * sigmoid(wh))**2 * anchor.

  Returns:
    scaler: [height, width, height, width] tensor used to normalize boxes.
    scaled_box: boxes as [sigmoid-xy offset, wh] in feature-map units.
    pred_box: boxes divided by `scaler`, i.e. in the [0, 1] range.
  """
  # split the boxes into centers and width/height
  pred_xy = encoded_boxes[..., 0:2]
  pred_wh = encoded_boxes[..., 2:4]

  # build a scaling tensor to get the offset of the box relative to the image
  scaler = tf.convert_to_tensor([height, width, height, width])
  scale_xy = tf.cast(scale_xy, pred_xy.dtype)

  # apply the sigmoid to both components
  pred_xy = tf.math.sigmoid(pred_xy)
  pred_wh = tf.math.sigmoid(pred_wh)

  # scale the xy offset predictions according to the config
  pred_xy = pred_xy * scale_xy - 0.5 * (scale_xy - 1)

  # find the true offset from the grid points and the scaler
  # where the grid points are the relative offset of each pixel with
  # in the image
  box_xy = grid_points + pred_xy

  # decode the width and height of the boxes and correlate them
  # to the anchor boxes
  box_wh = (2 * pred_wh)**2 * anchor_grid

  # build the final boxes
  scaled_box = tf.concat([box_xy, box_wh], axis=-1)
  pred_box = scaled_box / scaler

  # shift scaled boxes: keep the raw xy offsets alongside the decoded wh
  scaled_box = tf.concat([pred_xy, box_wh], axis=-1)
  return (scaler, scaled_box, pred_box)
@tf.custom_gradient
def _darknet_new_coord_boxes(encoded_boxes, width, height, anchor_grid,
                             grid_points, max_delta, scale_xy):
  """Darknet variant of `_new_coord_scale_boxes`: identical forward pass,
  with a custom backward pass that merges the gradients of `scaled_box` and
  `pred_box` and clips to `max_delta` (no exponential term here, since the
  'scaled' encoding uses a sigmoid on wh)."""
  (scaler, scaled_box,
   pred_box) = _new_coord_scale_boxes(encoded_boxes, width, height,
                                      anchor_grid, grid_points, scale_xy)

  def delta(dy_scaler, dy_scaled, dy):
    dy_xy, dy_wh = tf.split(dy, 2, axis=-1)
    dy_xy_, dy_wh_ = tf.split(dy_scaled, 2, axis=-1)

    # add all the gradients that may have been applied to the
    # boxes and those that have been applied to the width and height
    dy_wh += dy_wh_
    dy_xy += dy_xy_

    dbox = tf.concat([dy_xy, dy_wh], axis=-1)

    # apply the gradient clipping to xy and wh (after scrubbing nan/inf)
    dbox = math_ops.rm_nan_inf(dbox)
    delta = tf.cast(max_delta, dbox.dtype)
    dbox = tf.clip_by_value(dbox, -delta, delta)
    # only encoded_boxes receives a gradient; all other inputs get 0.0
    return dbox, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

  return (scaler, scaled_box, pred_box), delta
def _anchor_free_scale_boxes(encoded_boxes, width, height, stride,
                             grid_points, scale_xy):
  """Decodes anchor-free box encodings: no sigmoid, centers offset by the
  grid then multiplied by the stride, wh as exp(wh) * stride.

  Returns:
    scaler: [height, width, height, width] tensor used to normalize boxes.
    scaled_box: decoded boxes in input-image units.
    pred_box: boxes divided by `scaler`.
  """
  # split the boxes into centers and width/height
  pred_xy = encoded_boxes[..., 0:2]
  pred_wh = encoded_boxes[..., 2:4]

  # build a scaling tensor to get the offset of the box relative to the image
  scaler = tf.convert_to_tensor([height, width, height, width])
  scale_xy = tf.cast(scale_xy, encoded_boxes.dtype)

  # scale the centers and find the offset of each box relative to
  # their center pixel (note: unlike the other decoders, no sigmoid here)
  pred_xy = pred_xy * scale_xy - 0.5 * (scale_xy - 1)

  # scale the offsets and add them to the grid points or a tensor that is
  # the relative location of each pixel, then express in image pixels
  box_xy = (grid_points + pred_xy) * stride

  # scale the width and height of the predictions by the stride
  box_wh = tf.math.exp(pred_wh) * stride

  # build the final predicted box
  scaled_box = tf.concat([box_xy, box_wh], axis=-1)
  pred_box = scaled_box / scaler
  return (scaler, scaled_box, pred_box)
def get_predicted_box(width,
                      height,
                      encoded_boxes,
                      anchor_grid,
                      grid_points,
                      scale_xy,
                      stride,
                      darknet=False,
                      box_type="original",
                      max_delta=np.inf):
  """Decodes the predicted boxes from the model format to a usable
  [x, y, w, h] format for use in the loss function as well as for use
  within the detection generator.

  Args:
    width: A `float` scalar indicating the width of the prediction layer.
    height: A `float` scalar indicating the height of the prediction layer.
    encoded_boxes: A `Tensor` of shape [..., height, width, 4] holding encoded
      boxes.
    anchor_grid: A `Tensor` of shape [..., 1, 1, 2] holding the anchor boxes
      organized for box decoding, box width and height.
    grid_points: A `Tensor` of shape [..., height, width, 2] holding the anchor
      boxes for decoding the box centers.
    scale_xy: A `float` scaler used to indicate the range for each center
      outside of its given [..., i, j, 4] index, where i and j are indexing
      pixels along the width and height of the predicted output map.
    stride: An `int` defining the amount of down stride relative to the input
      image.
    darknet: A `bool` used to select between custom gradient and default
      autograd.
    box_type: An `str` indicating the type of box encoding that is being used:
      'anchor_free', 'scaled', or 'original'.
    max_delta: A `float` scaler used for gradient clipping in back propagation.

  Returns:
    scaler: A `Tensor` of shape [4] returned to allow the scaling of the ground
      truth boxes to be of the same magnitude as the decoded predicted boxes.
    scaled_box: A `Tensor` of shape [..., height, width, 4] with the predicted
      boxes.
    pred_box: A `Tensor` of shape [..., height, width, 4] with the predicted
      boxes divided by the scaler parameter used to put all boxes in the [0, 1]
      range.
  """
  # Anchor-free decoding ignores the anchor grid entirely.
  if box_type == 'anchor_free':
    return _anchor_free_scale_boxes(encoded_boxes, width, height, stride,
                                    grid_points, scale_xy)

  # Darknet loss: use the custom-gradient decoders so the box decoding is
  # not propagated through autograd and gradients can be clipped.
  if darknet:
    decoder = (
        _darknet_new_coord_boxes if box_type == 'scaled' else _darknet_boxes)
    return decoder(encoded_boxes, width, height, anchor_grid, grid_points,
                   max_delta, scale_xy)

  # Scaled loss: propagate the decoding of the boxes through autograd.
  decoder = (
      _new_coord_scale_boxes if box_type == 'scaled' else _scale_boxes)
  return decoder(encoded_boxes, width, height, anchor_grid, grid_points,
                 scale_xy)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment