ModelZoo / ResNet50_tensorflow · Commits

Commit e4748866
Authored Jun 27, 2022 by Denali Molitor; committed by A. Unique TensorFlower on Jun 27, 2022

Internal change

PiperOrigin-RevId: 457601447
Parent: ca6d7c57

Showing 2 changed files with 22 additions and 5 deletions (+22, -5)
official/benchmark/benchmark_lib.py        +21  -2
official/benchmark/benchmark_lib_test.py    +1  -3
official/benchmark/benchmark_lib.py

@@ -15,7 +15,7 @@
 """TFM common benchmark training driver."""
 import os
 import time
-from typing import Any, Mapping
+from typing import Any, Mapping, Optional
 from absl import logging
 import orbit
@@ -29,6 +29,19 @@ from official.modeling import performance
 from official.projects.token_dropping import experiment_configs  # pylint: disable=unused-import


+class _OutputRecorderAction:
+  """Simple `Action` that saves the outputs passed to `__call__`."""
+
+  def __init__(self):
+    self.train_output = {}
+
+  def __call__(self, output: Optional[Mapping[str, tf.Tensor]] = None
+              ) -> Mapping[str, Any]:
+    self.train_output = {k: v.numpy() for k, v in output.items()
+                        } if output else {}
+
+
 def run_benchmark(
     execution_mode: str,
     params: config_definitions.ExperimentConfig,
@@ -82,10 +95,13 @@ def run_benchmark(
   steps_per_loop = params.trainer.steps_per_loop if (
       execution_mode in ['accuracy', 'tflite_accuracy']) else 100
+  train_output_recorder = _OutputRecorderAction()
   controller = orbit.Controller(
       strategy=strategy,
       trainer=trainer,
       evaluator=trainer if (execution_mode == 'accuracy') else None,
+      train_actions=[train_output_recorder],
       global_step=trainer.global_step,
       steps_per_loop=steps_per_loop)
@@ -108,6 +124,9 @@ def run_benchmark(
           tf.convert_to_tensor(params.trainer.validation_steps))
       benchmark_data = {'metrics': eval_logs}
     elif execution_mode == 'performance':
-      benchmark_data = {}
+      if train_output_recorder.train_output:
+        benchmark_data = {'metrics': train_output_recorder.train_output}
+      else:
+        benchmark_data = {}
     elif execution_mode == 'tflite_accuracy':
       eval_logs = tflite_utils.train_and_evaluate(
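As I read the change, `_OutputRecorderAction` is a minimal Orbit-style action: `orbit.Controller` invokes each entry of `train_actions` with the output of a training loop, and the recorder keeps a NumPy copy of that mapping so the `performance` branch of `run_benchmark` can report it afterwards. Below is a standalone sketch of that behaviour (not part of the commit); the class body is copied from the diff above, and the tensor key and values are invented purely for illustration.

from typing import Any, Mapping, Optional

import tensorflow as tf


class _OutputRecorderAction:
  """Simple `Action` that saves the outputs passed to `__call__`."""

  def __init__(self):
    self.train_output = {}

  def __call__(self, output: Optional[Mapping[str, tf.Tensor]] = None
              ) -> Mapping[str, Any]:
    # Keep a NumPy copy of the latest train-loop output (empty dict if None).
    self.train_output = {k: v.numpy() for k, v in output.items()
                        } if output else {}


recorder = _OutputRecorderAction()
recorder({'training_loss': tf.constant(0.25)})  # recorder.train_output == {'training_loss': 0.25}
recorder(None)                                  # recorder.train_output == {}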
official/benchmark/benchmark_lib_test.py

@@ -80,8 +80,6 @@ class BenchmarkLibTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIn('examples_per_second', benchmark_data)
     self.assertIn('wall_time', benchmark_data)
     self.assertIn('startup_time', benchmark_data)
-    if execution_mode == 'accuracy':
-      self.assertIn('metrics', benchmark_data)
+    self.assertIn('metrics', benchmark_data)
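Taken with the library change above, the test no longer gates the 'metrics' assertion on accuracy mode, since a non-empty recorded train output now yields 'metrics' in performance mode as well. A hypothetical benchmark_data value that would satisfy these assertions is sketched below; the keys mirror the assertions above, and every value is invented for illustration only.

benchmark_data = {
    'examples_per_second': 1024.0,       # invented value
    'wall_time': 37.2,                   # invented value
    'startup_time': 5.8,                 # invented value
    'metrics': {'training_loss': 0.25},  # e.g. the recorder's train_output
}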