ModelZoo / ResNet50_tensorflow · Commits

Commit 1f3247f4 (unverified)
Authored Mar 27, 2020 by Ayushman Kumar; committed via GitHub on Mar 27, 2020.

    Merge pull request #6 from tensorflow/master

    Updated

Parents: 370a4c8d, 0265f59c

Changes: 85 files in the full commit; this page shows 20 changed files with 140 additions and 82 deletions (+140 / -82).
official/nlp/transformer/translate.py           +7   -8
official/nlp/transformer/utils/tokenizer.py     +15  -17
official/nlp/xlnet/training_utils.py            +1   -1
official/r1/mnist/mnist.py                      +5   -4
official/r1/mnist/mnist_test.py                 +2   -2
official/r1/resnet/cifar10_main.py              +7   -6
official/r1/resnet/cifar10_test.py              +3   -2
official/r1/resnet/estimator_benchmark.py       +2   -1
official/r1/resnet/imagenet_main.py             +5   -4
official/r1/resnet/imagenet_test.py             +2   -1
official/r1/resnet/resnet_model.py              +2   -2
official/r1/resnet/resnet_run_loop.py           +9   -11
official/r1/utils/data/file_io.py               +9   -9
official/r1/wide_deep/census_test.py            +2   -1
official/r1/wide_deep/movielens_test.py         +2   -1
official/requirements.txt                       +2   -1
official/staging/training/controller.py         +14  -7
official/staging/training/controller_test.py    +46  -0
official/staging/training/grad_utils.py         +3   -2
official/utils/flags/_device.py                 +2   -2
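Nearly every hunk on this page applies the same migration: calls through the TF1 shim tf.compat.v1.logging are replaced with absl logging (imported as "from absl import logging"), and eagerly %-formatted or .format()-built messages are switched to lazy printf-style arguments where the call sites were touched anyway. A minimal, self-contained sketch of the before/after pattern (the variable names are illustrative, not taken from the diff):

    import tensorflow as tf
    from absl import logging

    batch_index, num_batches = 3, 10  # made-up values for illustration

    # Before: TF1 compatibility shim, message built eagerly with %.
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    tf.compat.v1.logging.info(
        "Decoding batch %d out of %d." % (batch_index, num_batches))

    # After: absl logging; the logger interpolates the arguments only if the
    # record is actually emitted at the current verbosity.
    logging.set_verbosity(logging.INFO)
    logging.info("Decoding batch %d out of %d.", batch_index, num_batches)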
official/nlp/transformer/translate.py  (view file @ 1f3247f4)

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from absl import logging
 import numpy as np
 import tensorflow as tf

@@ -117,8 +118,7 @@ def translate_file(model,
          maxlen=params["decode_max_length"],
          dtype="int32",
          padding="post")
-      tf.compat.v1.logging.info("Decoding batch %d out of %d.", i,
-                                num_decode_batches)
+      logging.info("Decoding batch %d out of %d.", i, num_decode_batches)
       yield batch

   @tf.function

@@ -172,16 +172,15 @@ def translate_file(model,
         translation = _trim_and_decode(val_outputs[j], subtokenizer)
         translations.append(translation)
         if print_all_translations:
-          tf.compat.v1.logging.info(
-              "Translating:\n\tInput: %s\n\tOutput: %s" %
-              (sorted_inputs[j + i * batch_size], translation))
+          logging.info("Translating:\n\tInput: %s\n\tOutput: %s",
+                       sorted_inputs[j + i * batch_size], translation)

   # Write translations in the order they appeared in the original file.
   if output_file is not None:
     if tf.io.gfile.isdir(output_file):
       raise ValueError("File output is a directory, will not save outputs to "
                        "file.")
-    tf.compat.v1.logging.info("Writing to file %s" % output_file)
+    logging.info("Writing to file %s", output_file)
     with tf.compat.v1.gfile.Open(output_file, "w") as f:
       for i in sorted_keys:
         f.write("%s\n" % translations[i])

@@ -191,10 +190,10 @@ def translate_from_text(model, subtokenizer, txt):
   encoded_txt = _encode_and_add_eos(txt, subtokenizer)
   result = model.predict(encoded_txt)
   outputs = result["outputs"]
-  tf.compat.v1.logging.info("Original: \"%s\"" % txt)
+  logging.info("Original: \"%s\"", txt)
   translate_from_input(outputs, subtokenizer)


 def translate_from_input(outputs, subtokenizer):
   translation = _trim_and_decode(outputs, subtokenizer)
-  tf.compat.v1.logging.info("Translation: \"%s\"" % translation)
+  logging.info("Translation: \"%s\"", translation)
official/nlp/transformer/utils/tokenizer.py  (view file @ 1f3247f4)

@@ -22,6 +22,7 @@ import collections
 import re
 import sys
 import unicodedata
+from absl import logging
 import numpy as np
 import six

@@ -71,8 +72,7 @@ class Subtokenizer(object):
   def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
     """Initializes class, creating a vocab file if data_files is provided."""
-    tf.compat.v1.logging.info("Initializing Subtokenizer from file %s." %
-                              vocab_file)
+    logging.info("Initializing Subtokenizer from file %s.", vocab_file)
     if master_char_set is None:
       master_char_set = _ALPHANUMERIC_CHAR_SET

@@ -130,17 +130,17 @@ class Subtokenizer(object):
       reserved_tokens = RESERVED_TOKENS

     if tf.io.gfile.exists(vocab_file):
-      tf.compat.v1.logging.info("Vocab file already exists (%s)" % vocab_file)
+      logging.info("Vocab file already exists (%s)", vocab_file)
     else:
-      tf.compat.v1.logging.info("Begin steps to create subtoken vocabulary...")
+      logging.info("Begin steps to create subtoken vocabulary...")
       token_counts = _count_tokens(files, file_byte_limit, correct_strip,
                                    master_char_set)
       alphabet = _generate_alphabet_dict(token_counts)
       subtoken_list = _generate_subtokens_with_target_vocab_size(
           token_counts, alphabet, target_vocab_size, threshold, min_count,
           reserved_tokens)
-      tf.compat.v1.logging.info("Generated vocabulary with %d subtokens." %
-                                len(subtoken_list))
+      logging.info("Generated vocabulary with %d subtokens.",
+                   len(subtoken_list))
       _save_vocab_file(vocab_file, subtoken_list)
     return Subtokenizer(vocab_file, master_char_set=master_char_set)

@@ -439,23 +439,22 @@ def _generate_subtokens_with_target_vocab_size(token_counts,
     reserved_tokens = RESERVED_TOKENS

   if min_count is not None:
-    tf.compat.v1.logging.info(
-        "Using min_count=%d to generate vocab with target size %d" %
-        (min_count, target_size))
+    logging.info("Using min_count=%d to generate vocab with target size %d",
+                 min_count, target_size)
     return _generate_subtokens(
         token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)

   def bisect(min_val, max_val):
     """Recursive function to binary search for subtoken vocabulary."""
     cur_count = (min_val + max_val) // 2
-    tf.compat.v1.logging.info("Binary search: trying min_count=%d (%d %d)" %
-                              (cur_count, min_val, max_val))
+    logging.info("Binary search: trying min_count=%d (%d %d)", cur_count,
+                 min_val, max_val)
     subtoken_list = _generate_subtokens(
         token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)
     val = len(subtoken_list)
-    tf.compat.v1.logging.info(
-        "Binary search: min_count=%d resulted in %d tokens" %
-        (cur_count, val))
+    logging.info("Binary search: min_count=%d resulted in %d tokens",
+                 cur_count, val)
     within_threshold = abs(val - target_size) < threshold
     if within_threshold or min_val >= max_val or cur_count < 2:

@@ -471,8 +470,7 @@ def _generate_subtokens_with_target_vocab_size(token_counts,
       return other_subtoken_list
     return subtoken_list

-  tf.compat.v1.logging.info("Finding best min_count to get target size of %d" %
-                            target_size)
+  logging.info("Finding best min_count to get target size of %d", target_size)
   return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)

@@ -644,7 +642,7 @@ def _generate_subtokens(token_counts,
   # subtoken_dict, count how often the resulting subtokens appear, and update
   # the dictionary with subtokens w/ high enough counts.
   for i in xrange(num_iterations):
-    tf.compat.v1.logging.info("\tGenerating subtokens: iteration %d" % i)
+    logging.info("\tGenerating subtokens: iteration %d", i)
     # Generate new subtoken->id dictionary using the new subtoken list.
     subtoken_dict = _list_to_index_dict(subtoken_list)

@@ -658,5 +656,5 @@ def _generate_subtokens(token_counts,
     subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
         subtoken_counts, min_count, alphabet, reserved_tokens)
-    tf.compat.v1.logging.info("\tVocab size: %d" % len(subtoken_list))
+    logging.info("\tVocab size: %d", len(subtoken_list))
   return subtoken_list
official/nlp/xlnet/training_utils.py  (view file @ 1f3247f4)

@@ -28,7 +28,7 @@ from absl import logging
 import tensorflow as tf

 from typing import Any, Callable, Dict, Text, Optional

-from official.modeling import model_training_utils
+from official.nlp.bert import model_training_utils
 from official.nlp.xlnet import data_utils
 from official.nlp.xlnet import xlnet_modeling as modeling
official/r1/mnist/mnist.py  (view file @ 1f3247f4)

@@ -19,8 +19,9 @@ from __future__ import print_function
 from absl import app as absl_app
 from absl import flags
+from absl import logging
 from six.moves import range
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf

 from official.r1.mnist import dataset
 from official.utils.flags import core as flags_core

@@ -182,8 +183,8 @@ def run_mnist(flags_obj):
   data_format = flags_obj.data_format
   if data_format is None:
     data_format = ('channels_first'
-                   if tf.config.list_physical_devices('GPU')
+                   if tf.test.is_built_with_cuda()
                    else 'channels_last')
   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
       model_dir=flags_obj.model_dir,

@@ -241,6 +242,6 @@ def main(_):
 if __name__ == '__main__':
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_mnist_flags()
   absl_app.run(main)
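The only non-logging change in mnist.py (and in resnet_model.py further down) is which GPU probe picks the image data format. Both probes are real TensorFlow APIs with different semantics: tf.test.is_built_with_cuda() reports whether the installed binary was built with CUDA support, while tf.config.list_physical_devices('GPU') (available from TF 2.1) lists GPUs actually visible at runtime. A small, self-contained sketch of the selection pattern, using an illustrative helper name that is not part of the diff:

    import tensorflow as tf

    def pick_data_format():
      """Chooses NCHW vs NHWC the way mnist.py / resnet_model.py do.

      Illustrative sketch only; the build-time probe also works on TF 1.x,
      the runtime probe requires TF 2.1+.
      """
      if hasattr(tf.config, "list_physical_devices"):
        has_gpu = bool(tf.config.list_physical_devices("GPU"))  # runtime check
      else:
        has_gpu = tf.test.is_built_with_cuda()  # build-time check only
      return "channels_first" if has_gpu else "channels_last"

    print(pick_data_format())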
official/r1/mnist/mnist_test.py  (view file @ 1f3247f4)

@@ -21,7 +21,7 @@ import time
 import unittest

 import tensorflow as tf  # pylint: disable=g-bad-import-order
+from absl import logging

 from official.r1.mnist import mnist
 from official.utils.misc import keras_utils

@@ -143,5 +143,5 @@ class Benchmarks(tf.test.Benchmark):
 if __name__ == '__main__':
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+  logging.set_verbosity(logging.ERROR)
   tf.test.main()
official/r1/resnet/cifar10_main.py  (view file @ 1f3247f4)

@@ -22,8 +22,9 @@ import os
 from absl import app as absl_app
 from absl import flags
+from absl import logging
 from six.moves import range
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf

 from official.r1.resnet import resnet_model
 from official.r1.resnet import resnet_run_loop

@@ -139,9 +140,9 @@ def input_fn(is_training,
   dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

   if input_context:
-    tf.compat.v1.logging.info(
-        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d' % (
-            input_context.input_pipeline_id, input_context.num_input_pipelines))
+    logging.info(
+        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
+        input_context.input_pipeline_id, input_context.num_input_pipelines)
     dataset = dataset.shard(input_context.num_input_pipelines,
                             input_context.input_pipeline_id)

@@ -270,7 +271,7 @@ def run_cifar(flags_obj):
     Dictionary of results. Including final accuracy.
   """
   if flags_obj.image_bytes_as_serving_input:
-    tf.compat.v1.logging.fatal(
+    logging.fatal(
         '--image_bytes_as_serving_input cannot be set to True for CIFAR. '
         'This flag is only applicable to ImageNet.')
     return

@@ -291,6 +292,6 @@ def main(_):
 if __name__ == '__main__':
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_cifar_flags()
   absl_app.run(main)
official/r1/resnet/cifar10_test.py  (view file @ 1f3247f4)

@@ -19,14 +19,15 @@ from __future__ import print_function
 from tempfile import mkstemp

+from absl import logging
 import numpy as np
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf

 from official.r1.resnet import cifar10_main
 from official.utils.misc import keras_utils
 from official.utils.testing import integration

-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logging.set_verbosity(logging.ERROR)

 _BATCH_SIZE = 128
 _HEIGHT = 32
official/r1/resnet/estimator_benchmark.py  (view file @ 1f3247f4)

@@ -21,6 +21,7 @@ import os
 import time

 from absl import flags
+from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf  # pylint: disable=g-bad-import-order

@@ -56,7 +57,7 @@ class EstimatorBenchmark(tf.test.Benchmark):
   def _setup(self):
     """Sets up and resets flags before each test."""
-    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+    logging.set_verbosity(logging.INFO)
     if EstimatorBenchmark.local_flags is None:
       for flag_method in self.flag_methods:
         flag_method()
official/r1/resnet/imagenet_main.py  (view file @ 1f3247f4)

@@ -22,6 +22,7 @@ import os
 from absl import app as absl_app
 from absl import flags
+from absl import logging
 from six.moves import range
 import tensorflow as tf

@@ -194,9 +195,9 @@ def input_fn(is_training,
   dataset = tf.data.Dataset.from_tensor_slices(filenames)

   if input_context:
-    tf.compat.v1.logging.info(
-        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d' % (
-            input_context.input_pipeline_id, input_context.num_input_pipelines))
+    logging.info(
+        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
+        input_context.input_pipeline_id, input_context.num_input_pipelines)
     dataset = dataset.shard(input_context.num_input_pipelines,
                             input_context.input_pipeline_id)

@@ -387,6 +388,6 @@ def main(_):
 if __name__ == '__main__':
-  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+  logging.set_verbosity(logging.INFO)
   define_imagenet_flags()
   absl_app.run(main)
official/r1/resnet/imagenet_test.py  (view file @ 1f3247f4)

@@ -20,12 +20,13 @@ from __future__ import print_function
 import unittest

 import tensorflow as tf  # pylint: disable=g-bad-import-order
+from absl import logging

 from official.r1.resnet import imagenet_main
 from official.utils.misc import keras_utils
 from official.utils.testing import integration

-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logging.set_verbosity(logging.ERROR)

 _BATCH_SIZE = 32
 _LABEL_CLASSES = 1001
official/r1/resnet/resnet_model.py  (view file @ 1f3247f4)

@@ -391,8 +391,8 @@ class Model(object):
     self.resnet_size = resnet_size

     if not data_format:
       data_format = (
-          'channels_first' if tf.config.list_physical_devices('GPU')
+          'channels_first' if tf.test.is_built_with_cuda()
           else 'channels_last')

     self.resnet_version = resnet_version
     if resnet_version not in (1, 2):
official/r1/resnet/resnet_run_loop.py  (view file @ 1f3247f4)

@@ -29,6 +29,7 @@ import multiprocessing
 import os

 from absl import flags
+from absl import logging
 import tensorflow as tf

 from official.r1.resnet import imagenet_preprocessing

@@ -83,8 +84,8 @@ def process_record_dataset(dataset,
     options.experimental_threading.private_threadpool_size = (
         datasets_num_private_threads)
     dataset = dataset.with_options(options)
-    tf.compat.v1.logging.info('datasets_num_private_threads: %s',
-                              datasets_num_private_threads)
+    logging.info('datasets_num_private_threads: %s',
+                 datasets_num_private_threads)

   # Disable intra-op parallelism to optimize for throughput instead of latency.
   options = tf.data.Options()

@@ -205,17 +206,15 @@ def override_flags_and_set_envars_for_gpu_thread_pool(flags_obj):
     what has been set by the user on the command-line.
   """
   cpu_count = multiprocessing.cpu_count()
-  tf.compat.v1.logging.info('Logical CPU cores: %s', cpu_count)
+  logging.info('Logical CPU cores: %s', cpu_count)

   # Sets up thread pool for each GPU for op scheduling.
   per_gpu_thread_count = 1
   total_gpu_thread_count = per_gpu_thread_count * flags_obj.num_gpus
   os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode
   os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
-  tf.compat.v1.logging.info('TF_GPU_THREAD_COUNT: %s',
-                            os.environ['TF_GPU_THREAD_COUNT'])
-  tf.compat.v1.logging.info('TF_GPU_THREAD_MODE: %s',
-                            os.environ['TF_GPU_THREAD_MODE'])
+  logging.info('TF_GPU_THREAD_COUNT: %s', os.environ['TF_GPU_THREAD_COUNT'])
+  logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE'])

   # Reduces general thread pool by number of threads used for GPU pool.
   main_thread_count = cpu_count - total_gpu_thread_count

@@ -648,7 +647,7 @@ def resnet_main(
         hooks=train_hooks,
         max_steps=flags_obj.max_train_steps)
     eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
-    tf.compat.v1.logging.info('Starting to train and evaluate.')
+    logging.info('Starting to train and evaluate.')
     tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
     # tf.estimator.train_and_evalute doesn't return anything in multi-worker
     # case.

@@ -671,8 +670,7 @@ def resnet_main(
   schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.

   for cycle_index, num_train_epochs in enumerate(schedule):
-    tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
-                              int(n_loops))
+    logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

     if num_train_epochs:
       # Since we are calling classifier.train immediately in each loop, the

@@ -691,7 +689,7 @@ def resnet_main(
       # allows the eval (which is generally unimportant in those circumstances)
       # to terminate. Note that eval will run for max_train_steps each loop,
      # regardless of the global_step count.
-      tf.compat.v1.logging.info('Starting to evaluate.')
+      logging.info('Starting to evaluate.')
       eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                          steps=flags_obj.max_train_steps)
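For context on the environment variables this function configures (only the logger changes in this hunk): TF_GPU_THREAD_MODE=gpu_private gives each GPU its own op-scheduling thread pool and TF_GPU_THREAD_COUNT sizes it. A hedged sketch of setting them by hand, with illustrative values rather than the flag-derived ones used in resnet_run_loop.py:

    import os

    # Illustrative values; the diffed code derives these from command-line flags.
    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'  # private thread pool per GPU
    os.environ['TF_GPU_THREAD_COUNT'] = '1'           # threads in each GPU pool

    # The variables must be set before TensorFlow initializes its GPU devices.
    import tensorflow as tf  # noqa: E402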
official/r1/utils/data/file_io.py  (view file @ 1f3247f4)

@@ -25,10 +25,11 @@ import os
 import tempfile
 import uuid

+from absl import logging
 import numpy as np
 import six
 import tensorflow as tf

+# pylint:disable=logging-format-interpolation

 class _GarbageCollector(object):

@@ -50,9 +51,9 @@ class _GarbageCollector(object):
       for i in self.temp_buffers:
         if tf.io.gfile.exists(i):
           tf.io.gfile.remove(i)
-          tf.compat.v1.logging.info("Buffer file {} removed".format(i))
+          logging.info("Buffer file {} removed".format(i))
     except Exception as e:
-      tf.compat.v1.logging.error("Failed to cleanup buffer files: {}".format(e))
+      logging.error("Failed to cleanup buffer files: {}".format(e))

 _GARBAGE_COLLECTOR = _GarbageCollector()

@@ -176,7 +177,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
     actual_size = tf.io.gfile.stat(buffer_path).length
     if expected_size == actual_size:
       return buffer_path
-    tf.compat.v1.logging.warning(
+    logging.warning(
         "Existing buffer {} has size {}. Expected size {}. Deleting and "
         "rebuilding buffer.".format(buffer_path, actual_size, expected_size))
     tf.io.gfile.remove(buffer_path)

@@ -187,8 +188,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
   tf.io.gfile.makedirs(os.path.split(buffer_path)[0])
-  tf.compat.v1.logging.info("Constructing TFRecordDataset buffer: {}"
-                            .format(buffer_path))
+  logging.info("Constructing TFRecordDataset buffer: {}".format(buffer_path))

   count = 0
   pool = multiprocessing.dummy.Pool(multiprocessing.cpu_count())

@@ -198,10 +198,10 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
         rows_per_core=_ROWS_PER_CORE):
       _serialize_shards(df_shards, columns, pool, writer)
       count += sum([len(s) for s in df_shards])
-      tf.compat.v1.logging.info("{}/{} examples written."
-                                .format(str(count).ljust(8), len(dataframe)))
+      logging.info("{}/{} examples written."
+                   .format(str(count).ljust(8), len(dataframe)))
   finally:
     pool.terminate()
-  tf.compat.v1.logging.info("Buffer write complete.")
+  logging.info("Buffer write complete.")
   return buffer_path
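file_io.py keeps its str.format-built messages when switching to absl logging, which is why this hunk also adds a module-level "# pylint:disable=logging-format-interpolation"; that pylint check flags messages that are pre-formatted before being handed to a logger. A short sketch of the two styles (the path value is made up):

    from absl import logging

    buffer_path = "/tmp/example_buffer"  # illustrative value only

    # Lazy, printf-style arguments: the style the lint check prefers.
    logging.info("Buffer file %s removed", buffer_path)

    # Pre-formatted message: the style file_io.py keeps, hence the
    # logging-format-interpolation disable added in this diff.
    logging.info("Buffer file {} removed".format(buffer_path))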
official/r1/wide_deep/census_test.py  (view file @ 1f3247f4)

@@ -21,13 +21,14 @@ import os
 import unittest

 import tensorflow as tf  # pylint: disable=g-bad-import-order
+from absl import logging

 from official.utils.misc import keras_utils
 from official.utils.testing import integration
 from official.r1.wide_deep import census_dataset
 from official.r1.wide_deep import census_main

-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logging.set_verbosity(logging.ERROR)

 TEST_INPUT = ('18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,'
               'Husband,zyx,wvu,34,56,78,tsr,<=50K')
official/r1/wide_deep/movielens_test.py  (view file @ 1f3247f4)

@@ -28,8 +28,9 @@ from official.utils.misc import keras_utils
 from official.utils.testing import integration
 from official.r1.wide_deep import movielens_dataset
 from official.r1.wide_deep import movielens_main
+from absl import logging

-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logging.set_verbosity(logging.ERROR)

 TEST_INPUT_VALUES = {
official/requirements.txt  (view file @ 1f3247f4)

@@ -11,7 +11,8 @@ py-cpuinfo>=3.3.0
 scipy>=0.19.1
 tensorflow-hub>=0.6.0
 tensorflow-model-optimization>=0.2.1
-tensorflow_datasets
+tensorflow-datasets
+tensorflow-addons
 dataclasses
 gin-config
 typing
official/staging/training/controller.py  (view file @ 1f3247f4)

@@ -117,11 +117,18 @@ class Controller(object):
     if self.train_fn is not None:
       self.train_steps = train_steps
       self.steps_per_loop = steps_per_loop
-      self.summary_dir = summary_dir or checkpoint_manager.directory
+      if summary_dir:
+        self.summary_dir = summary_dir
+      elif checkpoint_manager:
+        self.summary_dir = checkpoint_manager.directory
+      else:
+        self.summary_dir = None

       self.summary_interval = summary_interval
-      summary_writer = tf.summary.create_file_writer(
-          self.summary_dir) if self.summary_interval else None
+      if self.summary_dir and self.summary_interval:
+        summary_writer = tf.summary.create_file_writer(self.summary_dir)
+      else:
+        summary_writer = None
       # TODO(rxsang): Consider pass SummaryManager directly into Controller for
       # maximum customizability.
       self.summary_manager = utils.SummaryManager(

@@ -140,14 +147,14 @@ class Controller(object):
       self.eval_steps = eval_steps
       self.eval_interval = eval_interval

-      # Create and initialize the interval triggers.
+      # Creates and initializes the interval triggers.
       self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
-                                                self.global_step.numpy())
+                                                self.global_step.numpy())  # pytype: disable=attribute-error

     if self.global_step:
       tf.summary.experimental.set_step(self.global_step)

-    # Restore Model if needed.
+    # Restores the model if needed.
     if self.checkpoint_manager is not None:
       model_restored = self._restore_model()
       if not model_restored and self.checkpoint_manager.checkpoint_interval:

@@ -192,7 +199,7 @@ class Controller(object):
       self.eval_summary_manager.flush()

   def _maybe_save_checkpoints(self, current_step, force_trigger=False):
-    if self.checkpoint_manager.checkpoint_interval:
+    if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
       ckpt_path = self.checkpoint_manager.save(
           checkpoint_number=current_step, check_interval=not force_trigger)
       if ckpt_path is not None:
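The controller.py hunks make both summary_dir and checkpoint_manager optional: the summary directory falls back to the checkpoint directory only when a checkpoint manager exists, the summary writer is created only when both a directory and an interval are set, and _maybe_save_checkpoints now guards against a missing manager. A minimal sketch of the same fallback logic, using hypothetical names rather than the Controller API itself:

    import tensorflow as tf

    def resolve_summary_writer(summary_dir, checkpoint_manager, summary_interval):
      """Returns (summary_dir, writer) with the precedence added in this diff."""
      if summary_dir:
        resolved_dir = summary_dir
      elif checkpoint_manager:
        resolved_dir = checkpoint_manager.directory  # tf.train.CheckpointManager
      else:
        resolved_dir = None

      if resolved_dir and summary_interval:
        writer = tf.summary.create_file_writer(resolved_dir)
      else:
        writer = None
      return resolved_dir, writer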
official/staging/training/controller_test.py  (view file @ 1f3247f4)

@@ -143,6 +143,52 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     super(ControllerTest, self).setUp()
     self.model_dir = self.get_temp_dir()

+  def test_no_checkpoint(self):
+    test_runnable = TestRunnable()
+    # No checkpoint manager and no strategy.
+    test_controller = controller.Controller(
+        train_fn=test_runnable.train,
+        eval_fn=test_runnable.evaluate,
+        global_step=test_runnable.global_step,
+        train_steps=10,
+        steps_per_loop=2,
+        summary_dir=os.path.join(self.model_dir, "summaries/train"),
+        summary_interval=2,
+        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+        eval_steps=2,
+        eval_interval=5)
+    test_controller.train(evaluate=True)
+    self.assertEqual(test_runnable.global_step.numpy(), 10)
+    # Loss and accuracy values should be written into summaries.
+    self.assertNotEmpty(
+        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
+    self.assertTrue(
+        check_eventfile_for_keyword(
+            "loss", os.path.join(self.model_dir, "summaries/train")))
+    self.assertNotEmpty(
+        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
+    self.assertTrue(
+        check_eventfile_for_keyword(
+            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
+    # No checkpoint, so global step starts from 0.
+    test_runnable.global_step.assign(0)
+    test_controller.train(evaluate=True)
+    self.assertEqual(test_runnable.global_step.numpy(), 10)
+
+  def test_no_checkpoint_and_summaries(self):
+    test_runnable = TestRunnable()
+    # No checkpoint + summary directories.
+    test_controller = controller.Controller(
+        train_fn=test_runnable.train,
+        eval_fn=test_runnable.evaluate,
+        global_step=test_runnable.global_step,
+        train_steps=10,
+        steps_per_loop=2,
+        eval_steps=2,
+        eval_interval=5)
+    test_controller.train(evaluate=True)
+    self.assertEqual(test_runnable.global_step.numpy(), 10)
+
   @combinations.generate(all_strategy_combinations())
   def test_train_and_evaluate(self, strategy):
     with strategy.scope():
official/staging/training/grad_utils.py  (view file @ 1f3247f4)

@@ -54,7 +54,7 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   This utils function is used when users intent to explicitly allreduce
   gradients and customize gradients operations before and after allreduce.
   The allreduced gradients are then passed to optimizer.apply_gradients(
-  all_reduce_sum_gradients=False).
+  experimental_aggregate_gradients=False).

   Arguments:
     grads_and_vars: gradients and variables pairs.

@@ -139,4 +139,5 @@ def minimize_using_explicit_allreduce(tape,
     grads_and_vars = zip(allreduced_grads, filtered_training_vars)
   if post_allreduce_callbacks:
     grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
-  optimizer.apply_gradients(grads_and_vars, all_reduce_sum_gradients=False)
+  optimizer.apply_gradients(
+      grads_and_vars, experimental_aggregate_gradients=False)
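The grad_utils.py change tracks a Keras optimizer API rename: the apply_gradients argument all_reduce_sum_gradients became experimental_aggregate_gradients (around TF 2.2). Passing False tells the optimizer not to aggregate gradients across replicas, since this code has already all-reduced them explicitly. A hedged, self-contained sketch of the renamed keyword (illustrative values; not the grad_utils internals):

    import tensorflow as tf

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    var = tf.Variable(1.0)
    grads_and_vars = [(tf.constant(0.5), var)]  # pretend these are pre-aggregated

    # Newer TF (about 2.2+): skip cross-replica aggregation inside apply_gradients.
    optimizer.apply_gradients(grads_and_vars,
                              experimental_aggregate_gradients=False)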
official/utils/flags/_device.py  (view file @ 1f3247f4)

@@ -19,7 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from absl import flags
-import tensorflow as tf
+from absl import logging

 from official.utils.flags._conventions import help_wrap

@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names):
     valid_flags = True
     for key in flag_names:
       if not flag_values[key].startswith("gs://"):
-        tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
+        logging.error("%s must be a GCS path.", key)
         valid_flags = False

     return valid_flags