Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dfa1a569
Commit
dfa1a569
authored
Oct 04, 2021
by
Vishnu Banna
Browse files
model task
parent
cac293d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
38 deletions
+40
-38
official/vision/beta/projects/yolo/tasks/task_utils.py
official/vision/beta/projects/yolo/tasks/task_utils.py
+37
-0
official/vision/beta/projects/yolo/tasks/yolo.py
official/vision/beta/projects/yolo/tasks/yolo.py
+3
-38
No files found.
official/vision/beta/projects/yolo/tasks/task_utils.py
0 → 100644
View file @
dfa1a569
import
tensorflow
as
tf
class
ListMetrics
:
"""Private class used to cleanly place the matric values for each level."""
def
__init__
(
self
,
metric_names
,
name
=
"ListMetrics"
,
**
kwargs
):
self
.
name
=
name
self
.
_metric_names
=
metric_names
self
.
_metrics
=
self
.
build_metric
()
return
def
build_metric
(
self
):
metric_names
=
self
.
_metric_names
metrics
=
[]
for
name
in
metric_names
:
metrics
.
append
(
tf
.
keras
.
metrics
.
Mean
(
name
,
dtype
=
tf
.
float32
))
return
metrics
def
update_state
(
self
,
loss_metrics
):
metrics
=
self
.
_metrics
for
m
in
metrics
:
m
.
update_state
(
loss_metrics
[
m
.
name
])
return
def
result
(
self
):
logs
=
dict
()
metrics
=
self
.
_metrics
for
m
in
metrics
:
logs
.
update
({
m
.
name
:
m
.
result
()})
return
logs
def
reset_states
(
self
):
metrics
=
self
.
_metrics
for
m
in
metrics
:
m
.
reset_states
()
return
\ No newline at end of file
official/vision/beta/projects/yolo/tasks/yolo.py
View file @
dfa1a569
...
@@ -33,6 +33,7 @@ from official.vision.beta.projects.yolo.ops import preprocessing_ops
...
@@ -33,6 +33,7 @@ from official.vision.beta.projects.yolo.ops import preprocessing_ops
from
official.vision.beta.projects.yolo.dataloaders
import
yolo_input
from
official.vision.beta.projects.yolo.dataloaders
import
yolo_input
from
official.vision.beta.projects.yolo.dataloaders
import
tf_example_decoder
from
official.vision.beta.projects.yolo.dataloaders
import
tf_example_decoder
from
official.vision.beta.projects.yolo.configs
import
yolo
as
exp_cfg
from
official.vision.beta.projects.yolo.configs
import
yolo
as
exp_cfg
from
official.vision.beta.projects.yolo.tasks
import
task_utils
import
tensorflow
as
tf
import
tensorflow
as
tf
from
typing
import
Optional
from
typing
import
Optional
...
@@ -185,7 +186,7 @@ class YoloTask(base_task.Task):
...
@@ -185,7 +186,7 @@ class YoloTask(base_task.Task):
metric_names
[
'net'
].
append
(
'conf'
)
metric_names
[
'net'
].
append
(
'conf'
)
for
i
,
key
in
enumerate
(
metric_names
.
keys
()):
for
i
,
key
in
enumerate
(
metric_names
.
keys
()):
metrics
.
append
(
_
ListMetrics
(
metric_names
[
key
],
name
=
key
))
metrics
.
append
(
task_utils
.
ListMetrics
(
metric_names
[
key
],
name
=
key
))
self
.
_metrics
=
metrics
self
.
_metrics
=
metrics
if
not
training
:
if
not
training
:
...
@@ -395,40 +396,4 @@ class YoloTask(base_task.Task):
...
@@ -395,40 +396,4 @@ class YoloTask(base_task.Task):
use_float16
=
use_float16
,
use_float16
=
use_float16
,
loss_scale
=
runtime_config
.
loss_scale
)
loss_scale
=
runtime_config
.
loss_scale
)
return
optimizer
return
optimizer
\ No newline at end of file
class
_ListMetrics
:
"""Private class used to cleanly place the matric values for each level."""
def
__init__
(
self
,
metric_names
,
name
=
"ListMetrics"
,
**
kwargs
):
self
.
name
=
name
self
.
_metric_names
=
metric_names
self
.
_metrics
=
self
.
build_metric
()
return
def
build_metric
(
self
):
metric_names
=
self
.
_metric_names
metrics
=
[]
for
name
in
metric_names
:
metrics
.
append
(
tf
.
keras
.
metrics
.
Mean
(
name
,
dtype
=
tf
.
float32
))
return
metrics
def
update_state
(
self
,
loss_metrics
):
metrics
=
self
.
_metrics
for
m
in
metrics
:
m
.
update_state
(
loss_metrics
[
m
.
name
])
return
def
result
(
self
):
logs
=
dict
()
metrics
=
self
.
_metrics
for
m
in
metrics
:
logs
.
update
({
m
.
name
:
m
.
result
()})
return
logs
def
reset_states
(
self
):
metrics
=
self
.
_metrics
for
m
in
metrics
:
m
.
reset_states
()
return
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment