Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cee4b75e
"src/vscode:/vscode.git/clone" did not exist on "62c2c547dbc9eee39d4ddc310dbd477df20c754b"
Commit
cee4b75e
authored
Mar 31, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 366120838
parent
cdf815f8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
10 deletions
+15
-10
official/vision/beta/modeling/retinanet_model.py
official/vision/beta/modeling/retinanet_model.py
+13
-8
official/vision/beta/modeling/retinanet_model_test.py
official/vision/beta/modeling/retinanet_model_test.py
+2
-2
No files found.
official/vision/beta/modeling/retinanet_model.py
View file @
cee4b75e
...
...
@@ -97,26 +97,31 @@ class RetinaNetModel(tf.keras.Model):
raw_scores
,
raw_boxes
,
raw_attributes
=
self
.
head
(
features
)
if
training
:
return
{
outputs
=
{
'cls_outputs'
:
raw_scores
,
'box_outputs'
:
raw_boxes
,
'att_outputs'
:
raw_attributes
,
}
if
raw_attributes
:
outputs
.
update
({
'att_outputs'
:
raw_attributes
})
return
outputs
else
:
# Post-processing.
final_results
=
self
.
detection_generator
(
raw_boxes
,
raw_scores
,
anchor_boxes
,
image_shape
,
raw_attributes
)
return
{
final_results
=
self
.
detection_generator
(
raw_boxes
,
raw_scores
,
anchor_boxes
,
image_shape
,
raw_attributes
)
outputs
=
{
'detection_boxes'
:
final_results
[
'detection_boxes'
],
'detection_scores'
:
final_results
[
'detection_scores'
],
'detection_classes'
:
final_results
[
'detection_classes'
],
'detection_attributes'
:
final_results
[
'detection_attributes'
],
'num_detections'
:
final_results
[
'num_detections'
],
'cls_outputs'
:
raw_scores
,
'box_outputs'
:
raw_boxes
,
'att_outputs'
:
raw_attributes
,
}
if
raw_attributes
:
outputs
.
update
({
'att_outputs'
:
raw_attributes
,
'detection_attributes'
:
final_results
[
'detection_attributes'
],
})
return
outputs
@
property
def
checkpoint_items
(
self
):
...
...
official/vision/beta/modeling/retinanet_model_test.py
View file @
cee4b75e
...
...
@@ -160,7 +160,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
if
training
:
cls_outputs
=
model_outputs
[
'cls_outputs'
]
box_outputs
=
model_outputs
[
'box_outputs'
]
att_outputs
=
model_outputs
[
'att_outputs'
]
for
level
in
range
(
min_level
,
max_level
+
1
):
self
.
assertIn
(
str
(
level
),
cls_outputs
)
self
.
assertIn
(
str
(
level
),
box_outputs
)
...
...
@@ -177,6 +176,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
4
*
num_anchors_per_location
],
box_outputs
[
str
(
level
)].
numpy
().
shape
)
if
has_att_heads
:
att_outputs
=
model_outputs
[
'att_outputs'
]
for
att
in
att_outputs
.
values
():
self
.
assertAllEqual
([
2
,
image_size
[
0
]
//
2
**
level
,
image_size
[
1
]
//
2
**
level
,
...
...
@@ -186,7 +186,6 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertIn
(
'detection_boxes'
,
model_outputs
)
self
.
assertIn
(
'detection_scores'
,
model_outputs
)
self
.
assertIn
(
'detection_classes'
,
model_outputs
)
self
.
assertIn
(
'detection_attributes'
,
model_outputs
)
self
.
assertIn
(
'num_detections'
,
model_outputs
)
self
.
assertAllEqual
(
[
2
,
10
,
4
],
model_outputs
[
'detection_boxes'
].
numpy
().
shape
)
...
...
@@ -197,6 +196,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllEqual
(
[
2
,],
model_outputs
[
'num_detections'
].
numpy
().
shape
)
if
has_att_heads
:
self
.
assertIn
(
'detection_attributes'
,
model_outputs
)
self
.
assertAllEqual
(
[
2
,
10
,
1
],
model_outputs
[
'detection_attributes'
][
'depth'
].
numpy
().
shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment