Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
041f6976
Commit
041f6976
authored
Nov 10, 2020
by
Pengchong Jin
Committed by
A. Unique TensorFlower
Nov 10, 2020
Browse files
Internal change
PiperOrigin-RevId: 341687182
parent
d2501e46
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
2 deletions
+10
-2
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+6
-0
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+4
-2
No files found.
official/vision/beta/configs/image_classification.py
View file @
041f6976
...
...
@@ -57,6 +57,11 @@ class Losses(hyperparams.Config):
l2_weight_decay
:
float
=
0.0
@
dataclasses
.
dataclass
class
Evaluation
(
hyperparams
.
Config
):
top_k
:
int
=
5
@
dataclasses
.
dataclass
class
ImageClassificationTask
(
cfg
.
TaskConfig
):
"""The task config."""
...
...
@@ -64,6 +69,7 @@ class ImageClassificationTask(cfg.TaskConfig):
train_data
:
DataConfig
=
DataConfig
(
is_training
=
True
)
validation_data
:
DataConfig
=
DataConfig
(
is_training
=
False
)
losses
:
Losses
=
Losses
()
evaluation
:
Evaluation
=
Evaluation
()
gradient_clip_norm
:
float
=
0.0
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint_modules
:
str
=
'all'
# all or backbone
...
...
official/vision/beta/tasks/image_classification.py
View file @
041f6976
...
...
@@ -123,15 +123,17 @@ class ImageClassificationTask(base_task.Task):
def
build_metrics
(
self
,
training
=
True
):
"""Gets streaming metrics for training/validation."""
k
=
self
.
task_config
.
evaluation
.
top_k
if
self
.
task_config
.
losses
.
one_hot
:
metrics
=
[
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_5_accuracy'
)]
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
else
:
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_
5
_accuracy'
)]
k
=
k
,
name
=
'top_
{}
_accuracy'
.
format
(
k
)
)]
return
metrics
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment