Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
91a1ce9b
Commit
91a1ce9b
authored
Dec 04, 2019
by
Yeqing Li
Committed by
A. Unique TensorFlower
Dec 04, 2019
Browse files
Code cleanup.
PiperOrigin-RevId: 283837279
parent
5b25005c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
31 deletions
+36
-31
official/vision/detection/evaluation/coco_evaluator.py
official/vision/detection/evaluation/coco_evaluator.py
+29
-0
official/vision/detection/evaluation/factory.py
official/vision/detection/evaluation/factory.py
+1
-1
official/vision/detection/modeling/retinanet_model.py
official/vision/detection/modeling/retinanet_model.py
+6
-30
No files found.
official/vision/detection/evaluation/coco_evaluator.py
View file @
91a1ce9b
...
@@ -42,6 +42,35 @@ from official.vision.detection.evaluation import coco_utils
...
@@ -42,6 +42,35 @@ from official.vision.detection.evaluation import coco_utils
from
official.vision.detection.utils
import
class_utils
from
official.vision.detection.utils
import
class_utils
class
MetricWrapper
(
object
):
# This is only a wrapper for COCO metric and works on for numpy array. So it
# doesn't inherit from tf.keras.layers.Layer or tf.keras.metrics.Metric.
def
__init__
(
self
,
evaluator
):
self
.
_evaluator
=
evaluator
def
update_state
(
self
,
y_true
,
y_pred
):
labels
=
tf
.
nest
.
map_structure
(
lambda
x
:
x
.
numpy
(),
y_true
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
x
.
numpy
(),
y_pred
)
groundtruths
=
{}
predictions
=
{}
for
key
,
val
in
outputs
.
items
():
if
isinstance
(
val
,
tuple
):
val
=
np
.
concatenate
(
val
)
predictions
[
key
]
=
val
for
key
,
val
in
labels
.
items
():
if
isinstance
(
val
,
tuple
):
val
=
np
.
concatenate
(
val
)
groundtruths
[
key
]
=
val
self
.
_evaluator
.
update
(
predictions
,
groundtruths
)
def
result
(
self
):
return
self
.
_evaluator
.
evaluate
()
def
reset_states
(
self
):
return
self
.
_evaluator
.
reset
()
class
COCOEvaluator
(
object
):
class
COCOEvaluator
(
object
):
"""COCO evaluation metric class."""
"""COCO evaluation metric class."""
...
...
official/vision/detection/evaluation/factory.py
View file @
91a1ce9b
...
@@ -32,4 +32,4 @@ def evaluator_generator(params):
...
@@ -32,4 +32,4 @@ def evaluator_generator(params):
else
:
else
:
raise
ValueError
(
'Evaluator %s is not supported.'
%
params
.
type
)
raise
ValueError
(
'Evaluator %s is not supported.'
%
params
.
type
)
return
evaluator
return
coco_
evaluator
.
MetricWrapper
(
evaluator
)
official/vision/detection/modeling/retinanet_model.py
View file @
91a1ce9b
...
@@ -32,35 +32,6 @@ from official.vision.detection.modeling.architecture import factory
...
@@ -32,35 +32,6 @@ from official.vision.detection.modeling.architecture import factory
from
official.vision.detection.ops
import
postprocess_ops
from
official.vision.detection.ops
import
postprocess_ops
class
COCOMetrics
(
object
):
# This is only a wrapper for COCO metric and works on for numpy array. So it
# doesn't inherit from tf.keras.layers.Layer or tf.keras.metrics.Metric.
def
__init__
(
self
,
params
):
self
.
_evaluator
=
eval_factory
.
evaluator_generator
(
params
.
eval
)
def
update_state
(
self
,
y_true
,
y_pred
):
labels
=
tf
.
nest
.
map_structure
(
lambda
x
:
x
.
numpy
(),
y_true
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
x
.
numpy
(),
y_pred
)
groundtruths
=
{}
predictions
=
{}
for
key
,
val
in
outputs
.
items
():
if
isinstance
(
val
,
tuple
):
val
=
np
.
concatenate
(
val
)
predictions
[
key
]
=
val
for
key
,
val
in
labels
.
items
():
if
isinstance
(
val
,
tuple
):
val
=
np
.
concatenate
(
val
)
groundtruths
[
key
]
=
val
self
.
_evaluator
.
update
(
predictions
,
groundtruths
)
def
result
(
self
):
return
self
.
_evaluator
.
evaluate
()
def
reset_states
(
self
):
return
self
.
_evaluator
.
reset
()
class
RetinanetModel
(
base_model
.
Model
):
class
RetinanetModel
(
base_model
.
Model
):
"""RetinaNet model function."""
"""RetinaNet model function."""
...
@@ -97,6 +68,11 @@ class RetinanetModel(base_model.Model):
...
@@ -97,6 +68,11 @@ class RetinanetModel(base_model.Model):
dtype
=
tf
.
bfloat16
if
self
.
_use_bfloat16
else
tf
.
float32
)
dtype
=
tf
.
bfloat16
if
self
.
_use_bfloat16
else
tf
.
float32
)
def
build_outputs
(
self
,
inputs
,
mode
):
def
build_outputs
(
self
,
inputs
,
mode
):
# If the input image is transposed (from NHWC to HWCN), we need to revert it
# back to the original shape before it's used in the computation.
if
self
.
_transpose_input
:
inputs
=
tf
.
transpose
(
inputs
,
[
3
,
0
,
1
,
2
])
backbone_features
=
self
.
_backbone_fn
(
backbone_features
=
self
.
_backbone_fn
(
inputs
,
is_training
=
(
mode
==
mode_keys
.
TRAIN
))
inputs
,
is_training
=
(
mode
==
mode_keys
.
TRAIN
))
fpn_features
=
self
.
_fpn_fn
(
fpn_features
=
self
.
_fpn_fn
(
...
@@ -192,4 +168,4 @@ class RetinanetModel(base_model.Model):
...
@@ -192,4 +168,4 @@ class RetinanetModel(base_model.Model):
return
labels
,
outputs
return
labels
,
outputs
def
eval_metrics
(
self
):
def
eval_metrics
(
self
):
return
COCOMetrics
(
self
.
_params
)
return
eval_factory
.
evaluator_generator
(
self
.
_params
.
eval
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment