Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ad3427f9
"docs/vscode:/vscode.git/clone" did not exist on "dd50d35f588a7aaaa03d0b36d775995edc1946cf"
Commit
ad3427f9
authored
Sep 21, 2021
by
Vishnu Banna
Browse files
loss function updates
parent
a15e242e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
28 deletions
+26
-28
official/vision/beta/projects/yolo/losses/yolo_loss.py
official/vision/beta/projects/yolo/losses/yolo_loss.py
+25
-27
official/vision/beta/projects/yolo/losses/yolo_loss_test.py
official/vision/beta/projects/yolo/losses/yolo_loss_test.py
+1
-1
No files found.
official/vision/beta/projects/yolo/losses/yolo_loss.py
View file @
ad3427f9
...
...
@@ -85,7 +85,7 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
max_delta: gradient clipping to apply to the box loss.
"""
self
.
_loss_type
=
loss_type
self
.
_classes
=
tf
.
constant
(
tf
.
cast
(
classes
,
dtype
=
tf
.
int32
))
self
.
_classes
=
classes
self
.
_num
=
tf
.
cast
(
len
(
mask
),
dtype
=
tf
.
int32
)
self
.
_truth_thresh
=
truth_thresh
self
.
_ignore_thresh
=
ignore_thresh
...
...
@@ -111,8 +111,8 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
max_delta
=
self
.
_max_delta
)
self
.
_decode_boxes
=
functools
.
partial
(
loss_utils
.
get_predicted_box
,
**
box_kwargs
)
self
.
_search_pairs
=
lambda
pred_boxes
,
pred_classes
,
boxes
,
classes
,
scale
,
yxyx
:
(
None
,
None
,
None
,
None
)
# pylint:disable=line-too-long
self
.
_
build_per_path_attributes
()
self
.
_
search_pairs
=
None
self
.
_build_per_path_attributes
()
def
box_loss
(
self
,
true_box
,
pred_box
,
darknet
=
False
):
...
...
@@ -136,14 +136,14 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
scale
=
None
):
"""Search of all groundtruths to associate groundtruths to predictions."""
if
self
.
_search_pairs
is
None
:
return
true_conf
,
tf
.
ones_like
(
true_conf
)
# Search all predictions against ground truths to find mathcing boxes for
# each pixel.
_
,
_
,
iou_max
,
_
=
self
.
_search_pairs
(
pred_boxes
,
pred_classes
,
boxes
,
classes
,
scale
=
scale
,
yxyx
=
True
)
if
iou_max
is
None
:
return
true_conf
,
tf
.
ones_like
(
true_conf
)
# Find the exact indexes to ignore and keep.
ignore_mask
=
tf
.
cast
(
iou_max
<
self
.
_ignore_thresh
,
pred_boxes
.
dtype
)
iou_mask
=
iou_max
>
self
.
_ignore_thresh
...
...
@@ -199,9 +199,6 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
grid_mask
)
=
self
.
_compute_loss
(
true_counts
,
inds
,
y_true
,
boxes
,
classes
,
y_pred
)
# Temporary metrics
box_loss
=
tf
.
stop_gradient
(
0.05
*
box_loss
/
self
.
_iou_normalizer
)
# Metric compute using done here to save time and resources.
sigmoid_conf
=
tf
.
stop_gradient
(
tf
.
sigmoid
(
pred_conf
))
iou
=
tf
.
stop_gradient
(
iou
)
...
...
@@ -314,8 +311,7 @@ class DarknetLoss(YoloLossBase):
anchor_grid
=
tf
.
stop_gradient
(
anchor_grid
)
# Split all the ground truths to use as seperate items in loss computation.
(
true_box
,
ind_mask
,
true_class
,
_
,
_
)
=
tf
.
split
(
y_true
,
[
4
,
1
,
1
,
1
,
1
],
axis
=-
1
)
(
true_box
,
ind_mask
,
true_class
)
=
tf
.
split
(
y_true
,
[
4
,
1
,
1
],
axis
=-
1
)
true_conf
=
tf
.
squeeze
(
true_conf
,
axis
=-
1
)
true_class
=
tf
.
squeeze
(
true_class
,
axis
=-
1
)
grid_mask
=
true_conf
...
...
@@ -439,6 +435,8 @@ class ScaledLoss(YoloLossBase):
if
self
.
_ignore_thresh
>
0.0
:
self
.
_search_pairs
=
loss_utils
.
PairWiseSearch
(
iou_type
=
self
.
_loss_type
,
any_match
=
False
,
min_conf
=
0.25
)
self
.
_cls_normalizer
=
self
.
_cls_normalizer
*
self
.
_classes
/
80
return
def
_compute_loss
(
self
,
true_counts
,
inds
,
y_true
,
boxes
,
classes
,
y_pred
):
...
...
@@ -457,8 +455,7 @@ class ScaledLoss(YoloLossBase):
width
,
height
,
batch_size
,
dtype
=
tf
.
float32
)
# Split the y_true list.
(
true_box
,
ind_mask
,
true_class
,
_
,
_
)
=
tf
.
split
(
y_true
,
[
4
,
1
,
1
,
1
,
1
],
axis
=-
1
)
(
true_box
,
ind_mask
,
true_class
)
=
tf
.
split
(
y_true
,
[
4
,
1
,
1
],
axis
=-
1
)
grid_mask
=
true_conf
=
tf
.
squeeze
(
true_conf
,
axis
=-
1
)
true_class
=
tf
.
squeeze
(
true_class
,
axis
=-
1
)
num_objs
=
tf
.
cast
(
tf
.
reduce_sum
(
ind_mask
),
dtype
=
y_pred
.
dtype
)
...
...
@@ -469,7 +466,7 @@ class ScaledLoss(YoloLossBase):
pred_box
,
pred_conf
,
pred_class
=
tf
.
split
(
y_pred
,
[
4
,
1
,
-
1
],
axis
=-
1
)
# Decode the boxes for loss compute.
scale
,
pred_box
,
_
=
self
.
_decode_boxes
(
scale
,
pred_box
,
pbg
=
self
.
_decode_boxes
(
fwidth
,
fheight
,
pred_box
,
anchor_grid
,
grid_points
,
darknet
=
False
)
# If the ignore threshold is enabled, search all boxes ignore all
...
...
@@ -477,20 +474,24 @@ class ScaledLoss(YoloLossBase):
# noted ground truth list.
if
self
.
_ignore_thresh
!=
0.0
:
(
_
,
obj_mask
)
=
self
.
_tiled_global_box_search
(
p
red_box
,
p
bg
,
tf
.
stop_gradient
(
tf
.
sigmoid
(
pred_class
)),
boxes
,
classes
,
true_conf
,
smoothed
=
False
,
scale
=
scal
e
)
scale
=
Non
e
)
# Scale and shift and select the ground truth boxes
# and predictions to the prediciton domain.
offset
=
tf
.
cast
(
tf
.
gather_nd
(
grid_points
,
inds
,
batch_dims
=
1
),
true_box
.
dtype
)
offset
=
tf
.
concat
([
offset
,
tf
.
zeros_like
(
offset
)],
axis
=-
1
)
true_box
=
loss_utils
.
apply_mask
(
ind_mask
,
(
scale
*
true_box
)
-
offset
)
if
self
.
_box_type
==
"anchor_free"
:
true_box
=
loss_utils
.
apply_mask
(
ind_mask
,
(
scale
*
self
.
_path_stride
*
true_box
))
else
:
offset
=
tf
.
cast
(
tf
.
gather_nd
(
grid_points
,
inds
,
batch_dims
=
1
),
true_box
.
dtype
)
offset
=
tf
.
concat
([
offset
,
tf
.
zeros_like
(
offset
)],
axis
=-
1
)
true_box
=
loss_utils
.
apply_mask
(
ind_mask
,
(
scale
*
true_box
)
-
offset
)
pred_box
=
loss_utils
.
apply_mask
(
ind_mask
,
tf
.
gather_nd
(
pred_box
,
inds
,
batch_dims
=
1
))
...
...
@@ -523,7 +524,9 @@ class ScaledLoss(YoloLossBase):
tf
.
expand_dims
(
true_conf
,
axis
=-
1
),
pred_conf
,
from_logits
=
True
)
if
self
.
_ignore_thresh
!=
0.0
:
bce
=
loss_utils
.
apply_mask
(
obj_mask
,
bce
)
conf_loss
=
tf
.
reduce_mean
(
bce
)
conf_loss
=
tf
.
reduce_sum
(
bce
)
/
tf
.
reduce_sum
(
obj_mask
)
else
:
conf_loss
=
tf
.
reduce_mean
(
bce
)
# Compute the cross entropy loss for the class maps.
class_loss
=
tf
.
keras
.
losses
.
binary_crossentropy
(
...
...
@@ -667,7 +670,7 @@ class YoloLoss:
update_on_repeat
=
update_on_repeat
,
label_smoothing
=
label_smoothing
)
def
__call__
(
self
,
ground_truth
,
predictions
,
use_reduced_logs
=
True
):
def
__call__
(
self
,
ground_truth
,
predictions
):
metric_dict
=
collections
.
defaultdict
(
dict
)
metric_dict
[
'net'
][
'box'
]
=
0
metric_dict
[
'net'
][
'class'
]
=
0
...
...
@@ -703,11 +706,6 @@ class YoloLoss:
metric_dict
[
key
][
'avg_iou'
]
=
tf
.
stop_gradient
(
avg_iou
)
metric_dict
[
key
][
'avg_obj'
]
=
tf
.
stop_gradient
(
avg_obj
)
if
not
use_reduced_logs
:
metric_dict
[
key
][
'conf_loss'
]
=
tf
.
stop_gradient
(
loss_conf
)
metric_dict
[
key
][
'box_loss'
]
=
tf
.
stop_gradient
(
loss_box
)
metric_dict
[
key
][
'class_loss'
]
=
tf
.
stop_gradient
(
loss_class
)
metric_dict
[
'net'
][
'box'
]
+=
tf
.
stop_gradient
(
loss_box
)
metric_dict
[
'net'
][
'class'
]
+=
tf
.
stop_gradient
(
loss_class
)
metric_dict
[
'net'
][
'conf'
]
+=
tf
.
stop_gradient
(
loss_conf
)
...
...
official/vision/beta/projects/yolo/losses/yolo_loss_test.py
View file @
ad3427f9
...
...
@@ -79,7 +79,7 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
'4'
:
[
1
,
300
,
3
],
'5'
:
[
1
,
300
,
3
]
},
tf
.
int32
)
truths
=
inpdict
({
'3'
:
[
1
,
300
,
8
],
'4'
:
[
1
,
300
,
8
],
'5'
:
[
1
,
300
,
8
]})
truths
=
inpdict
({
'3'
:
[
1
,
300
,
6
],
'4'
:
[
1
,
300
,
6
],
'5'
:
[
1
,
300
,
6
]})
boxes
=
tf
.
ones
([
1
,
300
,
4
],
dtype
=
tf
.
float32
)
classes
=
tf
.
ones
([
1
,
300
],
dtype
=
tf
.
float32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment