Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bd488858
Commit
bd488858
authored
Mar 20, 2020
by
A. Unique TensorFlower
Browse files
Merge pull request #8302 from ayushmankumar7:absl
PiperOrigin-RevId: 302043775
parents
2416dd9c
55bf4b80
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
24 additions
and
18 deletions
+24
-18
official/utils/logs/logger_test.py
official/utils/logs/logger_test.py
+5
-4
official/utils/logs/mlperf_helper.py
official/utils/logs/mlperf_helper.py
+6
-6
official/utils/misc/distribution_utils.py
official/utils/misc/distribution_utils.py
+3
-1
official/utils/misc/model_helpers.py
official/utils/misc/model_helpers.py
+7
-5
official/utils/testing/perfzero_benchmark.py
official/utils/testing/perfzero_benchmark.py
+3
-2
No files found.
official/utils/logs/logger_test.py
View file @
bd488858
...
@@ -28,6 +28,7 @@ import unittest
...
@@ -28,6 +28,7 @@ import unittest
import
mock
import
mock
from
absl.testing
import
flagsaver
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
try
:
try
:
from
google.cloud
import
bigquery
from
google.cloud
import
bigquery
...
@@ -79,7 +80,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
...
@@ -79,7 +80,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
mock_logger
=
mock
.
MagicMock
()
mock_logger
=
mock
.
MagicMock
()
mock_config_benchmark_logger
.
return_value
=
mock_logger
mock_config_benchmark_logger
.
return_value
=
mock_logger
with
logger
.
benchmark_context
(
None
):
with
logger
.
benchmark_context
(
None
):
tf
.
compat
.
v1
.
logging
.
info
(
"start benchmarking"
)
logging
.
info
(
"start benchmarking"
)
mock_logger
.
on_finish
.
assert_called_once_with
(
logger
.
RUN_STATUS_SUCCESS
)
mock_logger
.
on_finish
.
assert_called_once_with
(
logger
.
RUN_STATUS_SUCCESS
)
@
mock
.
patch
(
"official.utils.logs.logger.config_benchmark_logger"
)
@
mock
.
patch
(
"official.utils.logs.logger.config_benchmark_logger"
)
...
@@ -96,18 +97,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
...
@@ -96,18 +97,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
BaseBenchmarkLoggerTest
,
self
).
setUp
()
super
(
BaseBenchmarkLoggerTest
,
self
).
setUp
()
self
.
_actual_log
=
tf
.
compat
.
v1
.
logging
.
info
self
.
_actual_log
=
logging
.
info
self
.
logged_message
=
None
self
.
logged_message
=
None
def
mock_log
(
*
args
,
**
kwargs
):
def
mock_log
(
*
args
,
**
kwargs
):
self
.
logged_message
=
args
self
.
logged_message
=
args
self
.
_actual_log
(
*
args
,
**
kwargs
)
self
.
_actual_log
(
*
args
,
**
kwargs
)
tf
.
compat
.
v1
.
logging
.
info
=
mock_log
logging
.
info
=
mock_log
def
tearDown
(
self
):
def
tearDown
(
self
):
super
(
BaseBenchmarkLoggerTest
,
self
).
tearDown
()
super
(
BaseBenchmarkLoggerTest
,
self
).
tearDown
()
tf
.
compat
.
v1
.
logging
.
info
=
self
.
_actual_log
logging
.
info
=
self
.
_actual_log
def
test_log_metric
(
self
):
def
test_log_metric
(
self
):
log
=
logger
.
BaseBenchmarkLogger
()
log
=
logger
.
BaseBenchmarkLogger
()
...
...
official/utils/logs/mlperf_helper.py
View file @
bd488858
...
@@ -31,8 +31,9 @@ import re
...
@@ -31,8 +31,9 @@ import re
import
subprocess
import
subprocess
import
sys
import
sys
import
typing
import
typing
from
absl
import
logging
# pylint:disable=logging-format-interpolation
import
tensorflow
as
tf
_MIN_VERSION
=
(
0
,
0
,
10
)
_MIN_VERSION
=
(
0
,
0
,
10
)
_STACK_OFFSET
=
2
_STACK_OFFSET
=
2
...
@@ -94,8 +95,7 @@ def get_mlperf_log():
...
@@ -94,8 +95,7 @@ def get_mlperf_log():
version
=
pkg_resources
.
get_distribution
(
"mlperf_compliance"
)
version
=
pkg_resources
.
get_distribution
(
"mlperf_compliance"
)
version
=
tuple
(
int
(
i
)
for
i
in
version
.
version
.
split
(
"."
))
version
=
tuple
(
int
(
i
)
for
i
in
version
.
version
.
split
(
"."
))
if
version
<
_MIN_VERSION
:
if
version
<
_MIN_VERSION
:
tf
.
compat
.
v1
.
logging
.
warning
(
logging
.
warning
(
"mlperf_compliance is version {}, must be >= {}"
.
format
(
"mlperf_compliance is version {}, must be >= {}"
.
format
(
"."
.
join
([
str
(
i
)
for
i
in
version
]),
"."
.
join
([
str
(
i
)
for
i
in
version
]),
"."
.
join
([
str
(
i
)
for
i
in
_MIN_VERSION
])))
"."
.
join
([
str
(
i
)
for
i
in
_MIN_VERSION
])))
raise
ImportError
raise
ImportError
...
@@ -187,6 +187,6 @@ def clear_system_caches():
...
@@ -187,6 +187,6 @@ def clear_system_caches():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
with
LOGGER
(
True
):
with
LOGGER
(
True
):
ncf_print
(
key
=
TAGS
.
RUN_START
)
ncf_print
(
key
=
TAGS
.
RUN_START
)
official/utils/misc/distribution_utils.py
View file @
bd488858
...
@@ -22,6 +22,8 @@ import json
...
@@ -22,6 +22,8 @@ import json
import
os
import
os
import
random
import
random
import
string
import
string
from
absl
import
logging
import
tensorflow.compat.v2
as
tf
import
tensorflow.compat.v2
as
tf
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
tpu_lib
...
@@ -252,7 +254,7 @@ class SyntheticIterator(object):
...
@@ -252,7 +254,7 @@ class SyntheticIterator(object):
def
_monkey_patch_dataset_method
(
strategy
):
def
_monkey_patch_dataset_method
(
strategy
):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def
make_dataset
(
self
,
dataset
):
def
make_dataset
(
self
,
dataset
):
tf
.
compat
.
v1
.
logging
.
info
(
'Using pure synthetic data.'
)
logging
.
info
(
'Using pure synthetic data.'
)
with
self
.
scope
():
with
self
.
scope
():
if
self
.
extended
.
_global_batch_size
:
# pylint: disable=protected-access
if
self
.
extended
.
_global_batch_size
:
# pylint: disable=protected-access
return
SyntheticDataset
(
dataset
,
self
.
num_replicas_in_sync
)
return
SyntheticDataset
(
dataset
,
self
.
num_replicas_in_sync
)
...
...
official/utils/misc/model_helpers.py
View file @
bd488858
...
@@ -20,8 +20,11 @@ from __future__ import print_function
...
@@ -20,8 +20,11 @@ from __future__ import print_function
import
numbers
import
numbers
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.util
import
nest
from
tensorflow.python.util
import
nest
# pylint:disable=logging-format-interpolation
def
past_stop_threshold
(
stop_threshold
,
eval_metric
):
def
past_stop_threshold
(
stop_threshold
,
eval_metric
):
...
@@ -48,8 +51,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
...
@@ -48,8 +51,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
"must be a number."
)
"must be a number."
)
if
eval_metric
>=
stop_threshold
:
if
eval_metric
>=
stop_threshold
:
tf
.
compat
.
v1
.
logging
.
info
(
logging
.
info
(
"Stop threshold of {} was passed with metric value {}."
.
format
(
"Stop threshold of {} was passed with metric value {}."
.
format
(
stop_threshold
,
eval_metric
))
stop_threshold
,
eval_metric
))
return
True
return
True
...
@@ -88,6 +90,6 @@ def generate_synthetic_data(
...
@@ -88,6 +90,6 @@ def generate_synthetic_data(
def
apply_clean
(
flags_obj
):
def
apply_clean
(
flags_obj
):
if
flags_obj
.
clean
and
tf
.
io
.
gfile
.
exists
(
flags_obj
.
model_dir
):
if
flags_obj
.
clean
and
tf
.
io
.
gfile
.
exists
(
flags_obj
.
model_dir
):
tf
.
compat
.
v1
.
logging
.
info
(
"--clean flag set. Removing existing model dir:"
logging
.
info
(
"--clean flag set. Removing existing model dir:"
" {}"
.
format
(
flags_obj
.
model_dir
))
" {}"
.
format
(
flags_obj
.
model_dir
))
tf
.
io
.
gfile
.
rmtree
(
flags_obj
.
model_dir
)
tf
.
io
.
gfile
.
rmtree
(
flags_obj
.
model_dir
)
official/utils/testing/perfzero_benchmark.py
View file @
bd488858
...
@@ -20,8 +20,9 @@ from __future__ import print_function
...
@@ -20,8 +20,9 @@ from __future__ import print_function
import
os
import
os
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl.testing
import
flagsaver
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -75,7 +76,7 @@ class PerfZeroBenchmark(tf.test.Benchmark):
...
@@ -75,7 +76,7 @@ class PerfZeroBenchmark(tf.test.Benchmark):
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up and resets flags before each test."""
"""Sets up and resets flags before each test."""
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
if
PerfZeroBenchmark
.
local_flags
is
None
:
if
PerfZeroBenchmark
.
local_flags
is
None
:
for
flag_method
in
self
.
flag_methods
:
for
flag_method
in
self
.
flag_methods
:
flag_method
()
flag_method
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment