Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
695c9d58
Commit
695c9d58
authored
Oct 01, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 334910093
parent
52979660
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
54 additions
and
3 deletions
+54
-3
official/vision/beta/configs/maskrcnn.py
official/vision/beta/configs/maskrcnn.py
+1
-0
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+1
-0
official/vision/beta/evaluation/coco_evaluator.py
official/vision/beta/evaluation/coco_evaluator.py
+47
-1
official/vision/beta/tasks/maskrcnn.py
official/vision/beta/tasks/maskrcnn.py
+2
-1
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+3
-1
No files found.
official/vision/beta/configs/maskrcnn.py
View file @
695c9d58
...
@@ -208,6 +208,7 @@ class MaskRCNNTask(cfg.TaskConfig):
...
@@ -208,6 +208,7 @@ class MaskRCNNTask(cfg.TaskConfig):
init_checkpoint_modules
:
str
=
'all'
# all or backbone
init_checkpoint_modules
:
str
=
'all'
# all or backbone
annotation_file
:
Optional
[
str
]
=
None
annotation_file
:
Optional
[
str
]
=
None
gradient_clip_norm
:
float
=
0.0
gradient_clip_norm
:
float
=
0.0
per_category_metrics
=
False
COCO_INPUT_PATH_BASE
=
'coco'
COCO_INPUT_PATH_BASE
=
'coco'
...
...
official/vision/beta/configs/retinanet.py
View file @
695c9d58
...
@@ -129,6 +129,7 @@ class RetinaNetTask(cfg.TaskConfig):
...
@@ -129,6 +129,7 @@ class RetinaNetTask(cfg.TaskConfig):
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint_modules
:
str
=
'all'
# all or backbone
init_checkpoint_modules
:
str
=
'all'
# all or backbone
gradient_clip_norm
:
float
=
0.0
gradient_clip_norm
:
float
=
0.0
per_category_metrics
=
False
@
exp_factory
.
register_config_factory
(
'retinanet'
)
@
exp_factory
.
register_config_factory
(
'retinanet'
)
...
...
official/vision/beta/evaluation/coco_evaluator.py
View file @
695c9d58
...
@@ -41,7 +41,11 @@ from official.vision.beta.evaluation import coco_utils
...
@@ -41,7 +41,11 @@ from official.vision.beta.evaluation import coco_utils
class
COCOEvaluator
(
object
):
class
COCOEvaluator
(
object
):
"""COCO evaluation metric class."""
"""COCO evaluation metric class."""
def
__init__
(
self
,
annotation_file
,
include_mask
,
need_rescale_bboxes
=
True
):
def
__init__
(
self
,
annotation_file
,
include_mask
,
need_rescale_bboxes
=
True
,
per_category_metrics
=
False
):
"""Constructs COCO evaluation class.
"""Constructs COCO evaluation class.
The class provides the interface to COCO metrics_fn. The
The class provides the interface to COCO metrics_fn. The
...
@@ -57,6 +61,7 @@ class COCOEvaluator(object):
...
@@ -57,6 +61,7 @@ class COCOEvaluator(object):
eval.
eval.
need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
to absolute values (`image_info` is needed in this case).
to absolute values (`image_info` is needed in this case).
per_category_metrics: Whether to return per category metrics.
"""
"""
if
annotation_file
:
if
annotation_file
:
if
annotation_file
.
startswith
(
'gs://'
):
if
annotation_file
.
startswith
(
'gs://'
):
...
@@ -72,6 +77,7 @@ class COCOEvaluator(object):
...
@@ -72,6 +77,7 @@ class COCOEvaluator(object):
annotation_file
=
local_val_json
)
annotation_file
=
local_val_json
)
self
.
_annotation_file
=
annotation_file
self
.
_annotation_file
=
annotation_file
self
.
_include_mask
=
include_mask
self
.
_include_mask
=
include_mask
self
.
_per_category_metrics
=
per_category_metrics
self
.
_metric_names
=
[
self
.
_metric_names
=
[
'AP'
,
'AP50'
,
'AP75'
,
'APs'
,
'APm'
,
'APl'
,
'ARmax1'
,
'ARmax10'
,
'AP'
,
'AP50'
,
'AP75'
,
'APs'
,
'APm'
,
'APl'
,
'ARmax1'
,
'ARmax10'
,
'ARmax100'
,
'ARs'
,
'ARm'
,
'ARl'
'ARmax100'
,
'ARs'
,
'ARm'
,
'ARl'
...
@@ -156,6 +162,46 @@ class COCOEvaluator(object):
...
@@ -156,6 +162,46 @@ class COCOEvaluator(object):
metrics_dict
=
{}
metrics_dict
=
{}
for
i
,
name
in
enumerate
(
self
.
_metric_names
):
for
i
,
name
in
enumerate
(
self
.
_metric_names
):
metrics_dict
[
name
]
=
metrics
[
i
].
astype
(
np
.
float32
)
metrics_dict
[
name
]
=
metrics
[
i
].
astype
(
np
.
float32
)
# Adds metrics per category.
if
self
.
_per_category_metrics
and
hasattr
(
coco_eval
,
'category_stats'
):
for
category_index
,
category_id
in
enumerate
(
coco_eval
.
params
.
catIds
):
metrics_dict
[
'Precision mAP ByCategory/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
0
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Precision mAP ByCategory@50IoU/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
1
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Precision mAP ByCategory@75IoU/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
2
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Precision mAP ByCategory (small) /{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
3
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Precision mAP ByCategory (medium) /{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
4
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Precision mAP ByCategory (large) /{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
5
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Recall AR@1 ByCategory/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
6
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Recall AR@10 ByCategory/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
7
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Recall AR@100 ByCategory/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
8
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Recall AR (small) ByCategory/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
9
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Recall AR (medium) ByCategory/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
10
][
category_index
].
astype
(
np
.
float32
)
metrics_dict
[
'Recall AR (large) ByCategory/{}'
.
format
(
category_id
)]
=
coco_eval
.
category_stats
[
11
][
category_index
].
astype
(
np
.
float32
)
return
metrics_dict
return
metrics_dict
def
_process_predictions
(
self
,
predictions
):
def
_process_predictions
(
self
,
predictions
):
...
...
official/vision/beta/tasks/maskrcnn.py
View file @
695c9d58
...
@@ -204,7 +204,8 @@ class MaskRCNNTask(base_task.Task):
...
@@ -204,7 +204,8 @@ class MaskRCNNTask(base_task.Task):
else
:
else
:
self
.
coco_metric
=
coco_evaluator
.
COCOEvaluator
(
self
.
coco_metric
=
coco_evaluator
.
COCOEvaluator
(
annotation_file
=
self
.
_task_config
.
annotation_file
,
annotation_file
=
self
.
_task_config
.
annotation_file
,
include_mask
=
self
.
_task_config
.
model
.
include_mask
)
include_mask
=
self
.
_task_config
.
model
.
include_mask
,
per_category_metrics
=
self
.
_task_config
.
per_category_metrics
)
return
metrics
return
metrics
...
...
official/vision/beta/tasks/retinanet.py
View file @
695c9d58
...
@@ -178,7 +178,9 @@ class RetinaNetTask(base_task.Task):
...
@@ -178,7 +178,9 @@ class RetinaNetTask(base_task.Task):
if
not
training
:
if
not
training
:
self
.
coco_metric
=
coco_evaluator
.
COCOEvaluator
(
self
.
coco_metric
=
coco_evaluator
.
COCOEvaluator
(
annotation_file
=
None
,
include_mask
=
False
)
annotation_file
=
self
.
_task_config
.
annotation_file
,
include_mask
=
False
,
per_category_metrics
=
self
.
_task_config
.
per_category_metrics
)
return
metrics
return
metrics
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment