Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bec89e3a
Commit
bec89e3a
authored
May 31, 2022
by
Fan Yang
Committed by
A. Unique TensorFlower
May 31, 2022
Browse files
Add per-class precision and recall for image classification task.
PiperOrigin-RevId: 452075537
parent
3fa89ace
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
1 deletion
+23
-1
official/vision/configs/image_classification.py
official/vision/configs/image_classification.py
+1
-0
official/vision/tasks/image_classification.py
official/vision/tasks/image_classification.py
+22
-1
No files found.
official/vision/configs/image_classification.py
View file @
bec89e3a
...
@@ -82,6 +82,7 @@ class Losses(hyperparams.Config):
...
@@ -82,6 +82,7 @@ class Losses(hyperparams.Config):
class
Evaluation
(
hyperparams
.
Config
):
class
Evaluation
(
hyperparams
.
Config
):
top_k
:
int
=
5
top_k
:
int
=
5
precision_and_recall_thresholds
:
Optional
[
List
[
float
]]
=
None
precision_and_recall_thresholds
:
Optional
[
List
[
float
]]
=
None
report_per_class_precision_and_recall
:
bool
=
False
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/tasks/image_classification.py
View file @
bec89e3a
...
@@ -201,7 +201,28 @@ class ImageClassificationTask(base_task.Task):
...
@@ -201,7 +201,28 @@ class ImageClassificationTask(base_task.Task):
name
=
'recall_at_threshold_{}'
.
format
(
th
),
name
=
'recall_at_threshold_{}'
.
format
(
th
),
top_k
=
1
)
for
th
in
thresholds
top_k
=
1
)
for
th
in
thresholds
]
]
# pylint:enable=g-complex-comprehension
# Add per-class precision and recall.
if
hasattr
(
self
.
task_config
.
evaluation
,
'report_per_class_precision_and_recall'
)
and
self
.
task_config
.
evaluation
.
report_per_class_precision_and_recall
:
for
class_id
in
range
(
self
.
task_config
.
model
.
num_classes
):
metrics
+=
[
tf
.
keras
.
metrics
.
Precision
(
thresholds
=
th
,
class_id
=
class_id
,
name
=
f
'precision_at_threshold_
{
th
}
/
{
class_id
}
'
,
top_k
=
1
)
for
th
in
thresholds
]
metrics
+=
[
tf
.
keras
.
metrics
.
Recall
(
thresholds
=
th
,
class_id
=
class_id
,
name
=
f
'recall_at_threshold_
{
th
}
/
{
class_id
}
'
,
top_k
=
1
)
for
th
in
thresholds
]
# pylint:enable=g-complex-comprehension
else
:
else
:
metrics
=
[
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment