Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4bd2888b
Commit
4bd2888b
authored
Jul 19, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Jul 19, 2021
Browse files
Add FLOPs computation into run_experiment.
PiperOrigin-RevId: 385674527
parent
f93bea86
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
1 deletion
+54
-1
official/core/train_lib.py
official/core/train_lib.py
+5
-0
official/core/train_utils.py
official/core/train_utils.py
+49
-1
No files found.
official/core/train_lib.py
View file @
4bd2888b
...
...
@@ -138,6 +138,11 @@ def run_experiment(
logging
.
info
(
'Number of trainable params in model: %f Millions.'
,
num_params
/
10.
**
6
)
flops
=
train_utils
.
try_count_flops
(
trainer
.
model
)
if
flops
is
not
None
:
logging
.
info
(
'FLOPs (multi-adds) in model: %f Billions.'
,
flops
/
10.
**
9
/
2
)
if
run_post_eval
:
with
distribution_strategy
.
scope
():
return
trainer
.
model
,
trainer
.
evaluate
(
...
...
official/core/train_utils.py
View file @
4bd2888b
...
...
@@ -17,7 +17,7 @@ import copy
import
json
import
os
import
pprint
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
absl
import
logging
import
dataclasses
...
...
@@ -25,6 +25,9 @@ import gin
import
orbit
import
tensorflow
as
tf
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.framework.convert_to_constants
import
convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from
official.core
import
base_task
from
official.core
import
base_trainer
from
official.core
import
config_definitions
...
...
@@ -393,3 +396,48 @@ def try_count_params(model: tf.keras.Model):
'train step already reached before this run.'
)
return
None
return
None
def
try_count_flops
(
model
:
Union
[
tf
.
Module
,
tf
.
keras
.
Model
],
inputs_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
"""Counts and returns model FLOPs.
Args:
model: A model instance.
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
shape specifications to getting corresponding concrete function.
Returns:
The model's FLOPs.
"""
if
hasattr
(
model
,
'inputs'
):
try
:
# Get input shape and set batch size to 1.
if
model
.
inputs
:
inputs
=
[
tf
.
TensorSpec
([
1
]
+
input
.
shape
[
1
:],
input
.
dtype
)
for
input
in
model
.
inputs
]
concrete_func
=
tf
.
function
(
model
).
get_concrete_function
(
inputs
)
# If model.inputs is invalid, try to use the input to get concrete
# function for model.call (subclass model).
else
:
concrete_func
=
tf
.
function
(
model
.
call
).
get_concrete_function
(
**
inputs_kwargs
)
frozen_func
,
_
=
convert_variables_to_constants_v2_as_graph
(
concrete_func
)
# Calculate FLOPs.
run_meta
=
tf
.
compat
.
v1
.
RunMetadata
()
opts
=
tf
.
compat
.
v1
.
profiler
.
ProfileOptionBuilder
.
float_operation
()
opts
[
'output'
]
=
'none'
flops
=
tf
.
compat
.
v1
.
profiler
.
profile
(
graph
=
frozen_func
.
graph
,
run_meta
=
run_meta
,
options
=
opts
)
return
flops
.
total_float_ops
except
Exception
as
e
:
# pylint: disable=broad-except
logging
.
info
(
'Failed to count model FLOPs with error %s, because the build() '
'methods in keras layers were not called. This is probably because '
'the model was not feed any input, e.g., the max train step already '
'reached before this run.'
,
e
)
return
None
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment