Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3fc55e9e
"vscode:/vscode.git/clone" did not exist on "b4b7ed28bf76732fe27f40605d85a405f8df1fcf"
Commit
3fc55e9e
authored
Jun 02, 2021
by
Xianzhi Du
Committed by
A. Unique TensorFlower
Jun 02, 2021
Browse files
Internal change
PiperOrigin-RevId: 377216027
parent
0a770680
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
0 deletions
+46
-0
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+46
-0
No files found.
official/vision/beta/tasks/retinanet.py
View file @
3fc55e9e
...
@@ -133,12 +133,54 @@ class RetinaNetTask(base_task.Task):
...
@@ -133,12 +133,54 @@ class RetinaNetTask(base_task.Task):
return
dataset
return
dataset
def
build_attribute_loss
(
self
,
attribute_heads
:
List
[
exp_cfg
.
AttributeHead
],
outputs
:
Mapping
[
str
,
Any
],
labels
:
Mapping
[
str
,
Any
],
box_sample_weight
:
tf
.
Tensor
)
->
float
:
"""Computes attribute loss.
Args:
attribute_heads: a list of attribute head configs.
outputs: RetinaNet model outputs.
labels: RetinaNet labels.
box_sample_weight: normalized bounding box sample weights.
Returns:
Attribute loss of all attribute heads.
"""
attribute_loss
=
0.0
for
head
in
attribute_heads
:
if
head
.
name
not
in
labels
[
'attribute_targets'
]:
raise
ValueError
(
f
'Attribute
{
head
.
name
}
not found in label targets.'
)
if
head
.
name
not
in
outputs
[
'attribute_outputs'
]:
raise
ValueError
(
f
'Attribute
{
head
.
name
}
not found in model outputs.'
)
y_true_att
=
keras_cv
.
losses
.
multi_level_flatten
(
labels
[
'attribute_targets'
][
head
.
name
],
last_dim
=
head
.
size
)
y_pred_att
=
keras_cv
.
losses
.
multi_level_flatten
(
outputs
[
'attribute_outputs'
][
head
.
name
],
last_dim
=
head
.
size
)
if
head
.
type
==
'regression'
:
att_loss_fn
=
tf
.
keras
.
losses
.
Huber
(
1.0
,
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
att_loss
=
att_loss_fn
(
y_true
=
y_true_att
,
y_pred
=
y_pred_att
,
sample_weight
=
box_sample_weight
)
else
:
raise
ValueError
(
f
'Attribute type
{
head
.
type
}
not supported.'
)
attribute_loss
+=
att_loss
return
attribute_loss
def
build_losses
(
self
,
def
build_losses
(
self
,
outputs
:
Mapping
[
str
,
Any
],
outputs
:
Mapping
[
str
,
Any
],
labels
:
Mapping
[
str
,
Any
],
labels
:
Mapping
[
str
,
Any
],
aux_losses
:
Optional
[
Any
]
=
None
):
aux_losses
:
Optional
[
Any
]
=
None
):
"""Build RetinaNet losses."""
"""Build RetinaNet losses."""
params
=
self
.
task_config
params
=
self
.
task_config
attribute_heads
=
self
.
task_config
.
model
.
head
.
attribute_heads
cls_loss_fn
=
keras_cv
.
losses
.
FocalLoss
(
cls_loss_fn
=
keras_cv
.
losses
.
FocalLoss
(
alpha
=
params
.
losses
.
focal_loss_alpha
,
alpha
=
params
.
losses
.
focal_loss_alpha
,
gamma
=
params
.
losses
.
focal_loss_gamma
,
gamma
=
params
.
losses
.
focal_loss_gamma
,
...
@@ -170,6 +212,10 @@ class RetinaNetTask(base_task.Task):
...
@@ -170,6 +212,10 @@ class RetinaNetTask(base_task.Task):
model_loss
=
cls_loss
+
params
.
losses
.
box_loss_weight
*
box_loss
model_loss
=
cls_loss
+
params
.
losses
.
box_loss_weight
*
box_loss
if
attribute_heads
:
model_loss
+=
self
.
build_attribute_loss
(
attribute_heads
,
outputs
,
labels
,
box_sample_weight
)
total_loss
=
model_loss
total_loss
=
model_loss
if
aux_losses
:
if
aux_losses
:
reg_loss
=
tf
.
reduce_sum
(
aux_losses
)
reg_loss
=
tf
.
reduce_sum
(
aux_losses
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment