Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bd488858
Commit
bd488858
authored
Mar 20, 2020
by
A. Unique TensorFlower
Browse files
Merge pull request #8302 from ayushmankumar7:absl
PiperOrigin-RevId: 302043775
parents
2416dd9c
55bf4b80
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
96 additions
and
92 deletions
+96
-92
official/benchmark/models/resnet_cifar_main.py
official/benchmark/models/resnet_cifar_main.py
+3
-2
official/benchmark/ncf_keras_benchmark.py
official/benchmark/ncf_keras_benchmark.py
+2
-2
official/modeling/model_training_utils_test.py
official/modeling/model_training_utils_test.py
+2
-1
official/nlp/transformer/translate.py
official/nlp/transformer/translate.py
+7
-8
official/nlp/transformer/utils/tokenizer.py
official/nlp/transformer/utils/tokenizer.py
+15
-17
official/r1/mnist/mnist.py
official/r1/mnist/mnist.py
+3
-2
official/r1/mnist/mnist_test.py
official/r1/mnist/mnist_test.py
+2
-2
official/r1/resnet/cifar10_main.py
official/r1/resnet/cifar10_main.py
+7
-6
official/r1/resnet/cifar10_test.py
official/r1/resnet/cifar10_test.py
+3
-2
official/r1/resnet/estimator_benchmark.py
official/r1/resnet/estimator_benchmark.py
+2
-1
official/r1/resnet/imagenet_main.py
official/r1/resnet/imagenet_main.py
+5
-4
official/r1/resnet/imagenet_test.py
official/r1/resnet/imagenet_test.py
+2
-1
official/r1/resnet/resnet_run_loop.py
official/r1/resnet/resnet_run_loop.py
+9
-11
official/r1/utils/data/file_io.py
official/r1/utils/data/file_io.py
+9
-9
official/r1/wide_deep/census_test.py
official/r1/wide_deep/census_test.py
+2
-1
official/r1/wide_deep/movielens_test.py
official/r1/wide_deep/movielens_test.py
+2
-1
official/utils/flags/_device.py
official/utils/flags/_device.py
+2
-2
official/utils/logs/hooks_helper.py
official/utils/logs/hooks_helper.py
+4
-3
official/utils/logs/hooks_test.py
official/utils/logs/hooks_test.py
+2
-1
official/utils/logs/logger.py
official/utils/logs/logger.py
+13
-16
No files found.
official/benchmark/models/resnet_cifar_main.py
View file @
bd488858
...
@@ -20,6 +20,7 @@ from __future__ import print_function
...
@@ -20,6 +20,7 @@ from __future__ import print_function
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.benchmark.models
import
resnet_cifar_model
from
official.benchmark.models
import
resnet_cifar_model
...
@@ -100,7 +101,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
...
@@ -100,7 +101,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
if
lr
!=
self
.
prev_lr
:
if
lr
!=
self
.
prev_lr
:
self
.
model
.
optimizer
.
learning_rate
=
lr
# lr should be a float here
self
.
model
.
optimizer
.
learning_rate
=
lr
# lr should be a float here
self
.
prev_lr
=
lr
self
.
prev_lr
=
lr
tf
.
compat
.
v1
.
logging
.
debug
(
logging
.
debug
(
'Epoch %05d Batch %05d: LearningRateBatchScheduler '
'Epoch %05d Batch %05d: LearningRateBatchScheduler '
'change learning rate to %s.'
,
self
.
epochs
,
batch
,
lr
)
'change learning rate to %s.'
,
self
.
epochs
,
batch
,
lr
)
...
@@ -280,6 +281,6 @@ def main(_):
...
@@ -280,6 +281,6 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
define_cifar_flags
()
define_cifar_flags
()
app
.
run
(
main
)
app
.
run
(
main
)
official/benchmark/ncf_keras_benchmark.py
View file @
bd488858
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""
"""Executes Keras benchmarks and accuracy tests."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
...
@@ -22,6 +21,7 @@ import os
...
@@ -22,6 +21,7 @@ import os
import
time
import
time
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl.testing
import
flagsaver
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -51,7 +51,7 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
...
@@ -51,7 +51,7 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up and resets flags before each test."""
"""Sets up and resets flags before each test."""
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
if
NCFKerasBenchmarkBase
.
local_flags
is
None
:
if
NCFKerasBenchmarkBase
.
local_flags
is
None
:
ncf_common
.
define_ncf_flags
()
ncf_common
.
define_ncf_flags
()
# Loads flags to get defaults to then override. List cannot be empty.
# Loads flags to get defaults to then override. List cannot be empty.
...
...
official/modeling/model_training_utils_test.py
View file @
bd488858
...
@@ -20,6 +20,7 @@ from __future__ import print_function
...
@@ -20,6 +20,7 @@ from __future__ import print_function
import
os
import
os
from
absl
import
logging
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
from
absl.testing.absltest
import
mock
from
absl.testing.absltest
import
mock
import
numpy
as
np
import
numpy
as
np
...
@@ -125,7 +126,7 @@ def summaries_with_matching_keyword(keyword, summary_dir):
...
@@ -125,7 +126,7 @@ def summaries_with_matching_keyword(keyword, summary_dir):
if
event
.
summary
is
not
None
:
if
event
.
summary
is
not
None
:
for
value
in
event
.
summary
.
value
:
for
value
in
event
.
summary
.
value
:
if
keyword
in
value
.
tag
:
if
keyword
in
value
.
tag
:
tf
.
compat
.
v1
.
logging
.
error
(
event
)
logging
.
error
(
event
)
yield
event
.
summary
yield
event
.
summary
...
...
official/nlp/transformer/translate.py
View file @
bd488858
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -117,8 +118,7 @@ def translate_file(model,
...
@@ -117,8 +118,7 @@ def translate_file(model,
maxlen
=
params
[
"decode_max_length"
],
maxlen
=
params
[
"decode_max_length"
],
dtype
=
"int32"
,
dtype
=
"int32"
,
padding
=
"post"
)
padding
=
"post"
)
tf
.
compat
.
v1
.
logging
.
info
(
"Decoding batch %d out of %d."
,
i
,
logging
.
info
(
"Decoding batch %d out of %d."
,
i
,
num_decode_batches
)
num_decode_batches
)
yield
batch
yield
batch
@
tf
.
function
@
tf
.
function
...
@@ -172,16 +172,15 @@ def translate_file(model,
...
@@ -172,16 +172,15 @@ def translate_file(model,
translation
=
_trim_and_decode
(
val_outputs
[
j
],
subtokenizer
)
translation
=
_trim_and_decode
(
val_outputs
[
j
],
subtokenizer
)
translations
.
append
(
translation
)
translations
.
append
(
translation
)
if
print_all_translations
:
if
print_all_translations
:
tf
.
compat
.
v1
.
logging
.
info
(
logging
.
info
(
"Translating:
\n\t
Input: %s
\n\t
Output: %s"
,
"Translating:
\n\t
Input: %s
\n\t
Output: %s"
%
sorted_inputs
[
j
+
i
*
batch_size
],
translation
)
(
sorted_inputs
[
j
+
i
*
batch_size
],
translation
))
# Write translations in the order they appeared in the original file.
# Write translations in the order they appeared in the original file.
if
output_file
is
not
None
:
if
output_file
is
not
None
:
if
tf
.
io
.
gfile
.
isdir
(
output_file
):
if
tf
.
io
.
gfile
.
isdir
(
output_file
):
raise
ValueError
(
"File output is a directory, will not save outputs to "
raise
ValueError
(
"File output is a directory, will not save outputs to "
"file."
)
"file."
)
tf
.
compat
.
v1
.
logging
.
info
(
"Writing to file %s"
%
output_file
)
logging
.
info
(
"Writing to file %s"
,
output_file
)
with
tf
.
compat
.
v1
.
gfile
.
Open
(
output_file
,
"w"
)
as
f
:
with
tf
.
compat
.
v1
.
gfile
.
Open
(
output_file
,
"w"
)
as
f
:
for
i
in
sorted_keys
:
for
i
in
sorted_keys
:
f
.
write
(
"%s
\n
"
%
translations
[
i
])
f
.
write
(
"%s
\n
"
%
translations
[
i
])
...
@@ -191,10 +190,10 @@ def translate_from_text(model, subtokenizer, txt):
...
@@ -191,10 +190,10 @@ def translate_from_text(model, subtokenizer, txt):
encoded_txt
=
_encode_and_add_eos
(
txt
,
subtokenizer
)
encoded_txt
=
_encode_and_add_eos
(
txt
,
subtokenizer
)
result
=
model
.
predict
(
encoded_txt
)
result
=
model
.
predict
(
encoded_txt
)
outputs
=
result
[
"outputs"
]
outputs
=
result
[
"outputs"
]
tf
.
compat
.
v1
.
logging
.
info
(
"Original:
\"
%s
\"
"
%
txt
)
logging
.
info
(
"Original:
\"
%s
\"
"
,
txt
)
translate_from_input
(
outputs
,
subtokenizer
)
translate_from_input
(
outputs
,
subtokenizer
)
def
translate_from_input
(
outputs
,
subtokenizer
):
def
translate_from_input
(
outputs
,
subtokenizer
):
translation
=
_trim_and_decode
(
outputs
,
subtokenizer
)
translation
=
_trim_and_decode
(
outputs
,
subtokenizer
)
tf
.
compat
.
v1
.
logging
.
info
(
"Translation:
\"
%s
\"
"
%
translation
)
logging
.
info
(
"Translation:
\"
%s
\"
"
,
translation
)
official/nlp/transformer/utils/tokenizer.py
View file @
bd488858
...
@@ -22,6 +22,7 @@ import collections
...
@@ -22,6 +22,7 @@ import collections
import
re
import
re
import
sys
import
sys
import
unicodedata
import
unicodedata
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
six
import
six
...
@@ -71,8 +72,7 @@ class Subtokenizer(object):
...
@@ -71,8 +72,7 @@ class Subtokenizer(object):
def
__init__
(
self
,
vocab_file
,
reserved_tokens
=
None
,
master_char_set
=
None
):
def
__init__
(
self
,
vocab_file
,
reserved_tokens
=
None
,
master_char_set
=
None
):
"""Initializes class, creating a vocab file if data_files is provided."""
"""Initializes class, creating a vocab file if data_files is provided."""
tf
.
compat
.
v1
.
logging
.
info
(
"Initializing Subtokenizer from file %s."
%
logging
.
info
(
"Initializing Subtokenizer from file %s."
,
vocab_file
)
vocab_file
)
if
master_char_set
is
None
:
if
master_char_set
is
None
:
master_char_set
=
_ALPHANUMERIC_CHAR_SET
master_char_set
=
_ALPHANUMERIC_CHAR_SET
...
@@ -130,17 +130,17 @@ class Subtokenizer(object):
...
@@ -130,17 +130,17 @@ class Subtokenizer(object):
reserved_tokens
=
RESERVED_TOKENS
reserved_tokens
=
RESERVED_TOKENS
if
tf
.
io
.
gfile
.
exists
(
vocab_file
):
if
tf
.
io
.
gfile
.
exists
(
vocab_file
):
tf
.
compat
.
v1
.
logging
.
info
(
"Vocab file already exists (%s)"
%
vocab_file
)
logging
.
info
(
"Vocab file already exists (%s)"
,
vocab_file
)
else
:
else
:
tf
.
compat
.
v1
.
logging
.
info
(
"Begin steps to create subtoken vocabulary..."
)
logging
.
info
(
"Begin steps to create subtoken vocabulary..."
)
token_counts
=
_count_tokens
(
files
,
file_byte_limit
,
correct_strip
,
token_counts
=
_count_tokens
(
files
,
file_byte_limit
,
correct_strip
,
master_char_set
)
master_char_set
)
alphabet
=
_generate_alphabet_dict
(
token_counts
)
alphabet
=
_generate_alphabet_dict
(
token_counts
)
subtoken_list
=
_generate_subtokens_with_target_vocab_size
(
subtoken_list
=
_generate_subtokens_with_target_vocab_size
(
token_counts
,
alphabet
,
target_vocab_size
,
threshold
,
min_count
,
token_counts
,
alphabet
,
target_vocab_size
,
threshold
,
min_count
,
reserved_tokens
)
reserved_tokens
)
tf
.
compat
.
v1
.
logging
.
info
(
"Generated vocabulary with %d subtokens."
%
logging
.
info
(
"Generated vocabulary with %d subtokens."
,
len
(
subtoken_list
))
len
(
subtoken_list
))
_save_vocab_file
(
vocab_file
,
subtoken_list
)
_save_vocab_file
(
vocab_file
,
subtoken_list
)
return
Subtokenizer
(
vocab_file
,
master_char_set
=
master_char_set
)
return
Subtokenizer
(
vocab_file
,
master_char_set
=
master_char_set
)
...
@@ -439,23 +439,22 @@ def _generate_subtokens_with_target_vocab_size(token_counts,
...
@@ -439,23 +439,22 @@ def _generate_subtokens_with_target_vocab_size(token_counts,
reserved_tokens
=
RESERVED_TOKENS
reserved_tokens
=
RESERVED_TOKENS
if
min_count
is
not
None
:
if
min_count
is
not
None
:
tf
.
compat
.
v1
.
logging
.
info
(
logging
.
info
(
"Using min_count=%d to generate vocab with target size %d"
,
"Using min_count=%d to generate vocab with target size %d"
%
min_count
,
target_size
)
(
min_count
,
target_size
))
return
_generate_subtokens
(
return
_generate_subtokens
(
token_counts
,
alphabet
,
min_count
,
reserved_tokens
=
reserved_tokens
)
token_counts
,
alphabet
,
min_count
,
reserved_tokens
=
reserved_tokens
)
def
bisect
(
min_val
,
max_val
):
def
bisect
(
min_val
,
max_val
):
"""Recursive function to binary search for subtoken vocabulary."""
"""Recursive function to binary search for subtoken vocabulary."""
cur_count
=
(
min_val
+
max_val
)
//
2
cur_count
=
(
min_val
+
max_val
)
//
2
tf
.
compat
.
v1
.
logging
.
info
(
"Binary search: trying min_count=%d (%d %d)"
%
logging
.
info
(
"Binary search: trying min_count=%d (%d %d)"
,
cur_count
,
(
cur_count
,
min_val
,
max_val
)
)
min_val
,
max_val
)
subtoken_list
=
_generate_subtokens
(
subtoken_list
=
_generate_subtokens
(
token_counts
,
alphabet
,
cur_count
,
reserved_tokens
=
reserved_tokens
)
token_counts
,
alphabet
,
cur_count
,
reserved_tokens
=
reserved_tokens
)
val
=
len
(
subtoken_list
)
val
=
len
(
subtoken_list
)
tf
.
compat
.
v1
.
logging
.
info
(
logging
.
info
(
"Binary search: min_count=%d resulted in %d tokens"
,
cur_count
,
"Binary search: min_count=%d resulted in %d tokens"
%
(
cur_count
,
val
)
)
val
)
within_threshold
=
abs
(
val
-
target_size
)
<
threshold
within_threshold
=
abs
(
val
-
target_size
)
<
threshold
if
within_threshold
or
min_val
>=
max_val
or
cur_count
<
2
:
if
within_threshold
or
min_val
>=
max_val
or
cur_count
<
2
:
...
@@ -471,8 +470,7 @@ def _generate_subtokens_with_target_vocab_size(token_counts,
...
@@ -471,8 +470,7 @@ def _generate_subtokens_with_target_vocab_size(token_counts,
return
other_subtoken_list
return
other_subtoken_list
return
subtoken_list
return
subtoken_list
tf
.
compat
.
v1
.
logging
.
info
(
"Finding best min_count to get target size of %d"
%
logging
.
info
(
"Finding best min_count to get target size of %d"
,
target_size
)
target_size
)
return
bisect
(
_MIN_MIN_COUNT
,
_MAX_MIN_COUNT
)
return
bisect
(
_MIN_MIN_COUNT
,
_MAX_MIN_COUNT
)
...
@@ -644,7 +642,7 @@ def _generate_subtokens(token_counts,
...
@@ -644,7 +642,7 @@ def _generate_subtokens(token_counts,
# subtoken_dict, count how often the resulting subtokens appear, and update
# subtoken_dict, count how often the resulting subtokens appear, and update
# the dictionary with subtokens w/ high enough counts.
# the dictionary with subtokens w/ high enough counts.
for
i
in
xrange
(
num_iterations
):
for
i
in
xrange
(
num_iterations
):
tf
.
compat
.
v1
.
logging
.
info
(
"
\t
Generating subtokens: iteration %d"
%
i
)
logging
.
info
(
"
\t
Generating subtokens: iteration %d"
,
i
)
# Generate new subtoken->id dictionary using the new subtoken list.
# Generate new subtoken->id dictionary using the new subtoken list.
subtoken_dict
=
_list_to_index_dict
(
subtoken_list
)
subtoken_dict
=
_list_to_index_dict
(
subtoken_list
)
...
@@ -658,5 +656,5 @@ def _generate_subtokens(token_counts,
...
@@ -658,5 +656,5 @@ def _generate_subtokens(token_counts,
subtoken_list
,
max_subtoken_length
=
_gen_new_subtoken_list
(
subtoken_list
,
max_subtoken_length
=
_gen_new_subtoken_list
(
subtoken_counts
,
min_count
,
alphabet
,
reserved_tokens
)
subtoken_counts
,
min_count
,
alphabet
,
reserved_tokens
)
tf
.
compat
.
v1
.
logging
.
info
(
"
\t
Vocab size: %d"
%
len
(
subtoken_list
))
logging
.
info
(
"
\t
Vocab size: %d"
,
len
(
subtoken_list
))
return
subtoken_list
return
subtoken_list
official/r1/mnist/mnist.py
View file @
bd488858
...
@@ -19,8 +19,9 @@ from __future__ import print_function
...
@@ -19,8 +19,9 @@ from __future__ import print_function
from
absl
import
app
as
absl_app
from
absl
import
app
as
absl_app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
six.moves
import
range
from
six.moves
import
range
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
from
official.r1.mnist
import
dataset
from
official.r1.mnist
import
dataset
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
...
@@ -241,6 +242,6 @@ def main(_):
...
@@ -241,6 +242,6 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
define_mnist_flags
()
define_mnist_flags
()
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
official/r1/mnist/mnist_test.py
View file @
bd488858
...
@@ -21,7 +21,7 @@ import time
...
@@ -21,7 +21,7 @@ import time
import
unittest
import
unittest
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
from
official.r1.mnist
import
mnist
from
official.r1.mnist
import
mnist
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
...
@@ -143,5 +143,5 @@ class Benchmarks(tf.test.Benchmark):
...
@@ -143,5 +143,5 @@ class Benchmarks(tf.test.Benchmark):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
ERROR
)
logging
.
set_verbosity
(
logging
.
ERROR
)
tf
.
test
.
main
()
tf
.
test
.
main
()
official/r1/resnet/cifar10_main.py
View file @
bd488858
...
@@ -22,8 +22,9 @@ import os
...
@@ -22,8 +22,9 @@ import os
from
absl
import
app
as
absl_app
from
absl
import
app
as
absl_app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
six.moves
import
range
from
six.moves
import
range
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
from
official.r1.resnet
import
resnet_model
from
official.r1.resnet
import
resnet_model
from
official.r1.resnet
import
resnet_run_loop
from
official.r1.resnet
import
resnet_run_loop
...
@@ -139,9 +140,9 @@ def input_fn(is_training,
...
@@ -139,9 +140,9 @@ def input_fn(is_training,
dataset
=
tf
.
data
.
FixedLengthRecordDataset
(
filenames
,
_RECORD_BYTES
)
dataset
=
tf
.
data
.
FixedLengthRecordDataset
(
filenames
,
_RECORD_BYTES
)
if
input_context
:
if
input_context
:
tf
.
compat
.
v1
.
logging
.
info
(
logging
.
info
(
'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d'
%
(
'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d'
,
input_context
.
input_pipeline_id
,
input_context
.
num_input_pipelines
)
)
input_context
.
input_pipeline_id
,
input_context
.
num_input_pipelines
)
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
input_context
.
input_pipeline_id
)
input_context
.
input_pipeline_id
)
...
@@ -270,7 +271,7 @@ def run_cifar(flags_obj):
...
@@ -270,7 +271,7 @@ def run_cifar(flags_obj):
Dictionary of results. Including final accuracy.
Dictionary of results. Including final accuracy.
"""
"""
if
flags_obj
.
image_bytes_as_serving_input
:
if
flags_obj
.
image_bytes_as_serving_input
:
tf
.
compat
.
v1
.
logging
.
fatal
(
logging
.
fatal
(
'--image_bytes_as_serving_input cannot be set to True for CIFAR. '
'--image_bytes_as_serving_input cannot be set to True for CIFAR. '
'This flag is only applicable to ImageNet.'
)
'This flag is only applicable to ImageNet.'
)
return
return
...
@@ -291,6 +292,6 @@ def main(_):
...
@@ -291,6 +292,6 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
define_cifar_flags
()
define_cifar_flags
()
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
official/r1/resnet/cifar10_test.py
View file @
bd488858
...
@@ -19,14 +19,15 @@ from __future__ import print_function
...
@@ -19,14 +19,15 @@ from __future__ import print_function
from
tempfile
import
mkstemp
from
tempfile
import
mkstemp
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
from
official.r1.resnet
import
cifar10_main
from
official.r1.resnet
import
cifar10_main
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.utils.testing
import
integration
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
ERROR
)
logging
.
set_verbosity
(
logging
.
ERROR
)
_BATCH_SIZE
=
128
_BATCH_SIZE
=
128
_HEIGHT
=
32
_HEIGHT
=
32
...
...
official/r1/resnet/estimator_benchmark.py
View file @
bd488858
...
@@ -21,6 +21,7 @@ import os
...
@@ -21,6 +21,7 @@ import os
import
time
import
time
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl.testing
import
flagsaver
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
...
@@ -56,7 +57,7 @@ class EstimatorBenchmark(tf.test.Benchmark):
...
@@ -56,7 +57,7 @@ class EstimatorBenchmark(tf.test.Benchmark):
def
_setup
(
self
):
def
_setup
(
self
):
"""Sets up and resets flags before each test."""
"""Sets up and resets flags before each test."""
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
if
EstimatorBenchmark
.
local_flags
is
None
:
if
EstimatorBenchmark
.
local_flags
is
None
:
for
flag_method
in
self
.
flag_methods
:
for
flag_method
in
self
.
flag_methods
:
flag_method
()
flag_method
()
...
...
official/r1/resnet/imagenet_main.py
View file @
bd488858
...
@@ -22,6 +22,7 @@ import os
...
@@ -22,6 +22,7 @@ import os
from
absl
import
app
as
absl_app
from
absl
import
app
as
absl_app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
six.moves
import
range
from
six.moves
import
range
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -194,9 +195,9 @@ def input_fn(is_training,
...
@@ -194,9 +195,9 @@ def input_fn(is_training,
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
filenames
)
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
filenames
)
if
input_context
:
if
input_context
:
tf
.
compat
.
v1
.
logging
.
info
(
logging
.
info
(
'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d'
%
(
'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d'
,
input_context
.
input_pipeline_id
,
input_context
.
num_input_pipelines
)
)
input_context
.
input_pipeline_id
,
input_context
.
num_input_pipelines
)
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
input_context
.
input_pipeline_id
)
input_context
.
input_pipeline_id
)
...
@@ -387,6 +388,6 @@ def main(_):
...
@@ -387,6 +388,6 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
define_imagenet_flags
()
define_imagenet_flags
()
absl_app
.
run
(
main
)
absl_app
.
run
(
main
)
official/r1/resnet/imagenet_test.py
View file @
bd488858
...
@@ -20,12 +20,13 @@ from __future__ import print_function
...
@@ -20,12 +20,13 @@ from __future__ import print_function
import
unittest
import
unittest
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
from
official.r1.resnet
import
imagenet_main
from
official.r1.resnet
import
imagenet_main
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.utils.testing
import
integration
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
ERROR
)
logging
.
set_verbosity
(
logging
.
ERROR
)
_BATCH_SIZE
=
32
_BATCH_SIZE
=
32
_LABEL_CLASSES
=
1001
_LABEL_CLASSES
=
1001
...
...
official/r1/resnet/resnet_run_loop.py
View file @
bd488858
...
@@ -29,6 +29,7 @@ import multiprocessing
...
@@ -29,6 +29,7 @@ import multiprocessing
import
os
import
os
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.r1.resnet
import
imagenet_preprocessing
from
official.r1.resnet
import
imagenet_preprocessing
...
@@ -83,8 +84,8 @@ def process_record_dataset(dataset,
...
@@ -83,8 +84,8 @@ def process_record_dataset(dataset,
options
.
experimental_threading
.
private_threadpool_size
=
(
options
.
experimental_threading
.
private_threadpool_size
=
(
datasets_num_private_threads
)
datasets_num_private_threads
)
dataset
=
dataset
.
with_options
(
options
)
dataset
=
dataset
.
with_options
(
options
)
tf
.
compat
.
v1
.
logging
.
info
(
'datasets_num_private_threads: %s'
,
logging
.
info
(
'datasets_num_private_threads: %s'
,
datasets_num_private_threads
)
datasets_num_private_threads
)
# Disable intra-op parallelism to optimize for throughput instead of latency.
# Disable intra-op parallelism to optimize for throughput instead of latency.
options
=
tf
.
data
.
Options
()
options
=
tf
.
data
.
Options
()
...
@@ -205,17 +206,15 @@ def override_flags_and_set_envars_for_gpu_thread_pool(flags_obj):
...
@@ -205,17 +206,15 @@ def override_flags_and_set_envars_for_gpu_thread_pool(flags_obj):
what has been set by the user on the command-line.
what has been set by the user on the command-line.
"""
"""
cpu_count
=
multiprocessing
.
cpu_count
()
cpu_count
=
multiprocessing
.
cpu_count
()
tf
.
compat
.
v1
.
logging
.
info
(
'Logical CPU cores: %s'
,
cpu_count
)
logging
.
info
(
'Logical CPU cores: %s'
,
cpu_count
)
# Sets up thread pool for each GPU for op scheduling.
# Sets up thread pool for each GPU for op scheduling.
per_gpu_thread_count
=
1
per_gpu_thread_count
=
1
total_gpu_thread_count
=
per_gpu_thread_count
*
flags_obj
.
num_gpus
total_gpu_thread_count
=
per_gpu_thread_count
*
flags_obj
.
num_gpus
os
.
environ
[
'TF_GPU_THREAD_MODE'
]
=
flags_obj
.
tf_gpu_thread_mode
os
.
environ
[
'TF_GPU_THREAD_MODE'
]
=
flags_obj
.
tf_gpu_thread_mode
os
.
environ
[
'TF_GPU_THREAD_COUNT'
]
=
str
(
per_gpu_thread_count
)
os
.
environ
[
'TF_GPU_THREAD_COUNT'
]
=
str
(
per_gpu_thread_count
)
tf
.
compat
.
v1
.
logging
.
info
(
'TF_GPU_THREAD_COUNT: %s'
,
logging
.
info
(
'TF_GPU_THREAD_COUNT: %s'
,
os
.
environ
[
'TF_GPU_THREAD_COUNT'
])
os
.
environ
[
'TF_GPU_THREAD_COUNT'
])
logging
.
info
(
'TF_GPU_THREAD_MODE: %s'
,
os
.
environ
[
'TF_GPU_THREAD_MODE'
])
tf
.
compat
.
v1
.
logging
.
info
(
'TF_GPU_THREAD_MODE: %s'
,
os
.
environ
[
'TF_GPU_THREAD_MODE'
])
# Reduces general thread pool by number of threads used for GPU pool.
# Reduces general thread pool by number of threads used for GPU pool.
main_thread_count
=
cpu_count
-
total_gpu_thread_count
main_thread_count
=
cpu_count
-
total_gpu_thread_count
...
@@ -648,7 +647,7 @@ def resnet_main(
...
@@ -648,7 +647,7 @@ def resnet_main(
hooks
=
train_hooks
,
hooks
=
train_hooks
,
max_steps
=
flags_obj
.
max_train_steps
)
max_steps
=
flags_obj
.
max_train_steps
)
eval_spec
=
tf
.
estimator
.
EvalSpec
(
input_fn
=
input_fn_eval
)
eval_spec
=
tf
.
estimator
.
EvalSpec
(
input_fn
=
input_fn_eval
)
tf
.
compat
.
v1
.
logging
.
info
(
'Starting to train and evaluate.'
)
logging
.
info
(
'Starting to train and evaluate.'
)
tf
.
estimator
.
train_and_evaluate
(
classifier
,
train_spec
,
eval_spec
)
tf
.
estimator
.
train_and_evaluate
(
classifier
,
train_spec
,
eval_spec
)
# tf.estimator.train_and_evalute doesn't return anything in multi-worker
# tf.estimator.train_and_evalute doesn't return anything in multi-worker
# case.
# case.
...
@@ -671,8 +670,7 @@ def resnet_main(
...
@@ -671,8 +670,7 @@ def resnet_main(
schedule
[
-
1
]
=
train_epochs
-
sum
(
schedule
[:
-
1
])
# over counting.
schedule
[
-
1
]
=
train_epochs
-
sum
(
schedule
[:
-
1
])
# over counting.
for
cycle_index
,
num_train_epochs
in
enumerate
(
schedule
):
for
cycle_index
,
num_train_epochs
in
enumerate
(
schedule
):
tf
.
compat
.
v1
.
logging
.
info
(
'Starting cycle: %d/%d'
,
cycle_index
,
logging
.
info
(
'Starting cycle: %d/%d'
,
cycle_index
,
int
(
n_loops
))
int
(
n_loops
))
if
num_train_epochs
:
if
num_train_epochs
:
# Since we are calling classifier.train immediately in each loop, the
# Since we are calling classifier.train immediately in each loop, the
...
@@ -691,7 +689,7 @@ def resnet_main(
...
@@ -691,7 +689,7 @@ def resnet_main(
# allows the eval (which is generally unimportant in those circumstances)
# allows the eval (which is generally unimportant in those circumstances)
# to terminate. Note that eval will run for max_train_steps each loop,
# to terminate. Note that eval will run for max_train_steps each loop,
# regardless of the global_step count.
# regardless of the global_step count.
tf
.
compat
.
v1
.
logging
.
info
(
'Starting to evaluate.'
)
logging
.
info
(
'Starting to evaluate.'
)
eval_results
=
classifier
.
evaluate
(
input_fn
=
input_fn_eval
,
eval_results
=
classifier
.
evaluate
(
input_fn
=
input_fn_eval
,
steps
=
flags_obj
.
max_train_steps
)
steps
=
flags_obj
.
max_train_steps
)
...
...
official/r1/utils/data/file_io.py
View file @
bd488858
...
@@ -25,10 +25,11 @@ import os
...
@@ -25,10 +25,11 @@ import os
import
tempfile
import
tempfile
import
uuid
import
uuid
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint:disable=logging-format-interpolation
class
_GarbageCollector
(
object
):
class
_GarbageCollector
(
object
):
...
@@ -50,9 +51,9 @@ class _GarbageCollector(object):
...
@@ -50,9 +51,9 @@ class _GarbageCollector(object):
for
i
in
self
.
temp_buffers
:
for
i
in
self
.
temp_buffers
:
if
tf
.
io
.
gfile
.
exists
(
i
):
if
tf
.
io
.
gfile
.
exists
(
i
):
tf
.
io
.
gfile
.
remove
(
i
)
tf
.
io
.
gfile
.
remove
(
i
)
tf
.
compat
.
v1
.
logging
.
info
(
"Buffer file {} removed"
.
format
(
i
))
logging
.
info
(
"Buffer file {} removed"
.
format
(
i
))
except
Exception
as
e
:
except
Exception
as
e
:
tf
.
compat
.
v1
.
logging
.
error
(
"Failed to cleanup buffer files: {}"
.
format
(
e
))
logging
.
error
(
"Failed to cleanup buffer files: {}"
.
format
(
e
))
_GARBAGE_COLLECTOR
=
_GarbageCollector
()
_GARBAGE_COLLECTOR
=
_GarbageCollector
()
...
@@ -176,7 +177,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
...
@@ -176,7 +177,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
actual_size
=
tf
.
io
.
gfile
.
stat
(
buffer_path
).
length
actual_size
=
tf
.
io
.
gfile
.
stat
(
buffer_path
).
length
if
expected_size
==
actual_size
:
if
expected_size
==
actual_size
:
return
buffer_path
return
buffer_path
tf
.
compat
.
v1
.
logging
.
warning
(
logging
.
warning
(
"Existing buffer {} has size {}. Expected size {}. Deleting and "
"Existing buffer {} has size {}. Expected size {}. Deleting and "
"rebuilding buffer."
.
format
(
buffer_path
,
actual_size
,
expected_size
))
"rebuilding buffer."
.
format
(
buffer_path
,
actual_size
,
expected_size
))
tf
.
io
.
gfile
.
remove
(
buffer_path
)
tf
.
io
.
gfile
.
remove
(
buffer_path
)
...
@@ -187,8 +188,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
...
@@ -187,8 +188,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
split
(
buffer_path
)[
0
])
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
split
(
buffer_path
)[
0
])
tf
.
compat
.
v1
.
logging
.
info
(
"Constructing TFRecordDataset buffer: {}"
logging
.
info
(
"Constructing TFRecordDataset buffer: {}"
.
format
(
buffer_path
))
.
format
(
buffer_path
))
count
=
0
count
=
0
pool
=
multiprocessing
.
dummy
.
Pool
(
multiprocessing
.
cpu_count
())
pool
=
multiprocessing
.
dummy
.
Pool
(
multiprocessing
.
cpu_count
())
...
@@ -198,10 +198,10 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
...
@@ -198,10 +198,10 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
rows_per_core
=
_ROWS_PER_CORE
):
rows_per_core
=
_ROWS_PER_CORE
):
_serialize_shards
(
df_shards
,
columns
,
pool
,
writer
)
_serialize_shards
(
df_shards
,
columns
,
pool
,
writer
)
count
+=
sum
([
len
(
s
)
for
s
in
df_shards
])
count
+=
sum
([
len
(
s
)
for
s
in
df_shards
])
tf
.
compat
.
v1
.
logging
.
info
(
"{}/{} examples written."
logging
.
info
(
"{}/{} examples written."
.
format
(
.
format
(
str
(
count
).
ljust
(
8
),
len
(
dataframe
)))
str
(
count
).
ljust
(
8
),
len
(
dataframe
)))
finally
:
finally
:
pool
.
terminate
()
pool
.
terminate
()
tf
.
compat
.
v1
.
logging
.
info
(
"Buffer write complete."
)
logging
.
info
(
"Buffer write complete."
)
return
buffer_path
return
buffer_path
official/r1/wide_deep/census_test.py
View file @
bd488858
...
@@ -21,13 +21,14 @@ import os
...
@@ -21,13 +21,14 @@ import os
import
unittest
import
unittest
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.utils.testing
import
integration
from
official.r1.wide_deep
import
census_dataset
from
official.r1.wide_deep
import
census_dataset
from
official.r1.wide_deep
import
census_main
from
official.r1.wide_deep
import
census_main
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
ERROR
)
logging
.
set_verbosity
(
logging
.
ERROR
)
TEST_INPUT
=
(
'18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,'
TEST_INPUT
=
(
'18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,'
'Husband,zyx,wvu,34,56,78,tsr,<=50K'
)
'Husband,zyx,wvu,34,56,78,tsr,<=50K'
)
...
...
official/r1/wide_deep/movielens_test.py
View file @
bd488858
...
@@ -28,8 +28,9 @@ from official.utils.misc import keras_utils
...
@@ -28,8 +28,9 @@ from official.utils.misc import keras_utils
from
official.utils.testing
import
integration
from
official.utils.testing
import
integration
from
official.r1.wide_deep
import
movielens_dataset
from
official.r1.wide_deep
import
movielens_dataset
from
official.r1.wide_deep
import
movielens_main
from
official.r1.wide_deep
import
movielens_main
from
absl
import
logging
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
ERROR
)
logging
.
set_verbosity
(
logging
.
ERROR
)
TEST_INPUT_VALUES
=
{
TEST_INPUT_VALUES
=
{
...
...
official/utils/flags/_device.py
View file @
bd488858
...
@@ -19,7 +19,7 @@ from __future__ import division
...
@@ -19,7 +19,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
from
absl
import
flags
from
absl
import
flags
import
tensorflow
as
tf
from
absl
import
logging
from
official.utils.flags._conventions
import
help_wrap
from
official.utils.flags._conventions
import
help_wrap
...
@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names):
...
@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names):
valid_flags
=
True
valid_flags
=
True
for
key
in
flag_names
:
for
key
in
flag_names
:
if
not
flag_values
[
key
].
startswith
(
"gs://"
):
if
not
flag_values
[
key
].
startswith
(
"gs://"
):
tf
.
compat
.
v1
.
logging
.
error
(
"
{}
must be a GCS path."
.
format
(
key
)
)
logging
.
error
(
"
%s
must be a GCS path."
,
key
)
valid_flags
=
False
valid_flags
=
False
return
valid_flags
return
valid_flags
...
...
official/utils/logs/hooks_helper.py
View file @
bd488858
...
@@ -25,6 +25,7 @@ from __future__ import division
...
@@ -25,6 +25,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
from
official.utils.logs
import
hooks
from
official.utils.logs
import
hooks
from
official.utils.logs
import
logger
from
official.utils.logs
import
logger
...
@@ -57,9 +58,9 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
...
@@ -57,9 +58,9 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
return
[]
return
[]
if
use_tpu
:
if
use_tpu
:
tf
.
compat
.
v1
.
logging
.
warning
(
'hooks_helper received name_list `{}`, but a '
logging
.
warning
(
'TPU is specified. No hooks will be used.
'
'hooks_helper received name_list `%s`, but a
'
.
format
(
name_list
)
)
'TPU is specified. No hooks will be used.'
,
name_list
)
return
[]
return
[]
train_hooks
=
[]
train_hooks
=
[]
...
...
official/utils/logs/hooks_test.py
View file @
bd488858
...
@@ -21,12 +21,13 @@ from __future__ import print_function
...
@@ -21,12 +21,13 @@ from __future__ import print_function
import
time
import
time
from
absl
import
logging
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.utils.logs
import
hooks
from
official.utils.logs
import
hooks
from
official.utils.testing
import
mock_lib
from
official.utils.testing
import
mock_lib
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
DEBUG
)
logging
.
set_verbosity
(
logging
.
DEBUG
)
class
ExamplesPerSecondHookTest
(
tf
.
test
.
TestCase
):
class
ExamplesPerSecondHookTest
(
tf
.
test
.
TestCase
):
...
...
official/utils/logs/logger.py
View file @
bd488858
...
@@ -35,6 +35,7 @@ from six.moves import _thread as thread
...
@@ -35,6 +35,7 @@ from six.moves import _thread as thread
from
absl
import
flags
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.client
import
device_lib
from
tensorflow.python.client
import
device_lib
from
absl
import
logging
from
official.utils.logs
import
cloud_lib
from
official.utils.logs
import
cloud_lib
...
@@ -119,9 +120,8 @@ class BaseBenchmarkLogger(object):
...
@@ -119,9 +120,8 @@ class BaseBenchmarkLogger(object):
eval_results: dict, the result of evaluate.
eval_results: dict, the result of evaluate.
"""
"""
if
not
isinstance
(
eval_results
,
dict
):
if
not
isinstance
(
eval_results
,
dict
):
tf
.
compat
.
v1
.
logging
.
warning
(
logging
.
warning
(
"eval_results should be dictionary for logging. Got %s"
,
"eval_results should be dictionary for logging. Got %s"
,
type
(
eval_results
))
type
(
eval_results
))
return
return
global_step
=
eval_results
[
tf
.
compat
.
v1
.
GraphKeys
.
GLOBAL_STEP
]
global_step
=
eval_results
[
tf
.
compat
.
v1
.
GraphKeys
.
GLOBAL_STEP
]
for
key
in
sorted
(
eval_results
):
for
key
in
sorted
(
eval_results
):
...
@@ -144,12 +144,12 @@ class BaseBenchmarkLogger(object):
...
@@ -144,12 +144,12 @@ class BaseBenchmarkLogger(object):
"""
"""
metric
=
_process_metric_to_json
(
name
,
value
,
unit
,
global_step
,
extras
)
metric
=
_process_metric_to_json
(
name
,
value
,
unit
,
global_step
,
extras
)
if
metric
:
if
metric
:
tf
.
compat
.
v1
.
logging
.
info
(
"Benchmark metric: %s"
,
metric
)
logging
.
info
(
"Benchmark metric: %s"
,
metric
)
def
log_run_info
(
self
,
model_name
,
dataset_name
,
run_params
,
test_id
=
None
):
def
log_run_info
(
self
,
model_name
,
dataset_name
,
run_params
,
test_id
=
None
):
tf
.
compat
.
v1
.
logging
.
info
(
logging
.
info
(
"Benchmark run: %s"
,
_gather_run_info
(
model_name
,
dataset_name
,
"Benchmark run: %s"
,
run_params
,
test_id
))
_gather_run_info
(
model_name
,
dataset_name
,
run_params
,
test_id
))
def
on_finish
(
self
,
status
):
def
on_finish
(
self
,
status
):
pass
pass
...
@@ -187,7 +187,7 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
...
@@ -187,7 +187,7 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
self
.
_metric_file_handler
.
write
(
"
\n
"
)
self
.
_metric_file_handler
.
write
(
"
\n
"
)
self
.
_metric_file_handler
.
flush
()
self
.
_metric_file_handler
.
flush
()
except
(
TypeError
,
ValueError
)
as
e
:
except
(
TypeError
,
ValueError
)
as
e
:
tf
.
compat
.
v1
.
logging
.
warning
(
logging
.
warning
(
"Failed to dump metric to log file: name %s, value %s, error %s"
,
"Failed to dump metric to log file: name %s, value %s, error %s"
,
name
,
value
,
e
)
name
,
value
,
e
)
...
@@ -212,8 +212,7 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
...
@@ -212,8 +212,7 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
json
.
dump
(
run_info
,
f
)
json
.
dump
(
run_info
,
f
)
f
.
write
(
"
\n
"
)
f
.
write
(
"
\n
"
)
except
(
TypeError
,
ValueError
)
as
e
:
except
(
TypeError
,
ValueError
)
as
e
:
tf
.
compat
.
v1
.
logging
.
warning
(
logging
.
warning
(
"Failed to dump benchmark run info to log file: %s"
,
e
)
"Failed to dump benchmark run info to log file: %s"
,
e
)
def
on_finish
(
self
,
status
):
def
on_finish
(
self
,
status
):
self
.
_metric_file_handler
.
flush
()
self
.
_metric_file_handler
.
flush
()
...
@@ -322,8 +321,8 @@ def _process_metric_to_json(
...
@@ -322,8 +321,8 @@ def _process_metric_to_json(
name
,
value
,
unit
=
None
,
global_step
=
None
,
extras
=
None
):
name
,
value
,
unit
=
None
,
global_step
=
None
,
extras
=
None
):
"""Validate the metric data and generate JSON for insert."""
"""Validate the metric data and generate JSON for insert."""
if
not
isinstance
(
value
,
numbers
.
Number
):
if
not
isinstance
(
value
,
numbers
.
Number
):
tf
.
compat
.
v1
.
logging
.
warning
(
logging
.
warning
(
"Metric value to log should be a number. Got %s"
,
"Metric value to log should be a number. Got %s"
,
type
(
value
))
type
(
value
))
return
None
return
None
extras
=
_convert_to_json_dict
(
extras
)
extras
=
_convert_to_json_dict
(
extras
)
...
@@ -383,8 +382,7 @@ def _collect_cpu_info(run_info):
...
@@ -383,8 +382,7 @@ def _collect_cpu_info(run_info):
run_info
[
"machine_config"
][
"cpu_info"
]
=
cpu_info
run_info
[
"machine_config"
][
"cpu_info"
]
=
cpu_info
except
ImportError
:
except
ImportError
:
tf
.
compat
.
v1
.
logging
.
warn
(
logging
.
warn
(
"'cpuinfo' not imported. CPU info will not be logged."
)
"'cpuinfo' not imported. CPU info will not be logged."
)
def
_collect_memory_info
(
run_info
):
def
_collect_memory_info
(
run_info
):
...
@@ -396,8 +394,7 @@ def _collect_memory_info(run_info):
...
@@ -396,8 +394,7 @@ def _collect_memory_info(run_info):
run_info
[
"machine_config"
][
"memory_total"
]
=
vmem
.
total
run_info
[
"machine_config"
][
"memory_total"
]
=
vmem
.
total
run_info
[
"machine_config"
][
"memory_available"
]
=
vmem
.
available
run_info
[
"machine_config"
][
"memory_available"
]
=
vmem
.
available
except
ImportError
:
except
ImportError
:
tf
.
compat
.
v1
.
logging
.
warn
(
logging
.
warn
(
"'psutil' not imported. Memory info will not be logged."
)
"'psutil' not imported. Memory info will not be logged."
)
def
_collect_test_environment
(
run_info
):
def
_collect_test_environment
(
run_info
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment