ModelZoo / ResNet50_tensorflow / Commits / a8b5cb7a

Commit a8b5cb7a, authored Apr 10, 2020 by Hongkun Yu, committed by
A. Unique TensorFlower on Apr 10, 2020.

Internal change

PiperOrigin-RevId: 305897677
Parent: c0c58423

Showing 4 changed files with 150 additions and 126 deletions (+150 / -126):

  official/benchmark/models/resnet_cifar_main.py                          +3    -2
  official/benchmark/models/synthetic_util.py                           +129    -0
  official/utils/misc/distribution_utils.py                              +18  -122
  official/vision/image_classification/resnet/resnet_imagenet_main.py     +0    -2
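Taken together, the four diffs below move the synthetic-data helpers (`SyntheticDataset`, `SyntheticIterator`, and the Strategy monkey-patching functions) out of `official/utils/misc/distribution_utils.py` into the new benchmark-only module `official/benchmark/models/synthetic_util.py`, point the CIFAR benchmark at the new module, drop the now-unused calls from the ImageNet trainer, and switch the remaining `distribution_utils.py` docstrings and error messages from single to double quotes.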
official/benchmark/models/resnet_cifar_main.py  (view file @ a8b5cb7a)

@@ -24,6 +24,7 @@ from absl import logging
 import numpy as np
 import tensorflow as tf
 from official.benchmark.models import resnet_cifar_model
+from official.benchmark.models import synthetic_util
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
@@ -159,7 +160,7 @@ def run(flags_obj):
   strategy_scope = distribution_utils.get_strategy_scope(strategy)

   if flags_obj.use_synthetic_data:
-    distribution_utils.set_up_synthetic_data()
+    synthetic_util.set_up_synthetic_data()
     input_fn = common.get_synth_input_fn(
         height=cifar_preprocessing.HEIGHT,
         width=cifar_preprocessing.WIDTH,
@@ -168,7 +169,7 @@ def run(flags_obj):
         dtype=flags_core.get_tf_dtype(flags_obj),
         drop_remainder=True)
   else:
-    distribution_utils.undo_set_up_synthetic_data()
+    synthetic_util.undo_set_up_synthetic_data()
     input_fn = cifar_preprocessing.input_fn

   train_input_dataset = input_fn(
official/benchmark/models/synthetic_util.py  (new file, 0 → 100644, view file @ a8b5cb7a)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to generate data directly on devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import string

from absl import logging
import tensorflow as tf


# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, dataset, split_by=1):
    # dataset.take(1) doesn't have GPU kernel.
    with tf.device('device:CPU:0'):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    initializers = []
    for t in flat_tensor:
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=rebatched_t)
      variable_data.append(v)
      initializers.append(v.initializer)
    input_data = tf.nest.pack_sequence_as(tensor, variable_data)
    self._iterator = SyntheticIterator(input_data, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

  def __iter__(self):
    return self._iterator

  def make_one_shot_iterator(self):
    return self._iterator

  def make_initializable_iterator(self):
    return self._iterator


class SyntheticIterator(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, input_data, initializers):
    self._input_data = input_data
    self._initializers = initializers

  def get_next(self):
    return self._input_data

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers


def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s `make_dataset_iterator` method."""

  def make_dataset(self, dataset):
    logging.info('Using pure synthetic data.')
    with self.scope():
      if self.extended._global_batch_size:
        # pylint: disable=protected-access
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      else:
        return SyntheticDataset(dataset)

  def make_iterator(self, dataset):
    dist_dataset = make_dataset(self, dataset)
    return iter(dist_dataset)

  strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
  strategy.make_dataset_iterator = make_iterator
  strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
  strategy.experimental_distribute_dataset = make_dataset


def _undo_monkey_patch_dataset_method(strategy):
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    strategy.make_dataset_iterator = strategy.orig_distribute_dataset


def set_up_synthetic_data():
  _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)


def undo_set_up_synthetic_data():
  _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _undo_monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)
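The module above is what the CIFAR benchmark now imports. A minimal usage sketch follows; it is an illustration assuming the TF 2.x APIs available at the time of this commit, not code from the diff itself:

import tensorflow as tf

from official.benchmark.models import synthetic_util

# Patch the Strategy classes so experimental_distribute_dataset() returns a
# SyntheticDataset that replays one cached, per-replica slice of the first
# batch from device memory.
synthetic_util.set_up_synthetic_data()

strategy = tf.distribute.MirroredStrategy()

# Any dataset with fully defined shapes works; only its first element is read.
dataset = tf.data.Dataset.from_tensors(
    (tf.zeros([8, 32, 32, 3]), tf.zeros([8], tf.int64))).repeat()
dist_input = strategy.experimental_distribute_dataset(dataset)

# Restore the original Strategy methods before switching back to real data.
synthetic_util.undo_set_up_synthetic_data()

One caveat worth flagging: as committed, `_undo_monkey_patch_dataset_method` restores `orig_distribute_dataset` onto `make_dataset_iterator` rather than onto `experimental_distribute_dataset`, so the distribute path appears to remain patched after `undo_set_up_synthetic_data()`.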
official/utils/misc/distribution_utils.py  (view file @ a8b5cb7a)

@@ -40,7 +40,7 @@ def _collective_communication(all_reduce_alg):
     tf.distribute.experimental.CollectiveCommunication object

   Raises:
-    ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
+    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
   """
   collective_communication_options = {
       None: tf.distribute.experimental.CollectiveCommunication.AUTO,
@@ -50,7 +50,7 @@ def _collective_communication(all_reduce_alg):
   if all_reduce_alg not in collective_communication_options:
     raise ValueError(
         "When used with `multi_worker_mirrored`, valid values for "
-        "all_reduce_alg are ['ring', 'nccl']. Supplied value: {}".format(
+        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
             all_reduce_alg))
   return collective_communication_options[all_reduce_alg]
@@ -66,7 +66,7 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
     tf.distribute.CrossDeviceOps object or None.

   Raises:
-    ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
+    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
   """
   if all_reduce_alg is None:
     return None
@@ -77,7 +77,7 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
   if all_reduce_alg not in mirrored_all_reduce_options:
     raise ValueError(
         "When used with `mirrored`, valid values for all_reduce_alg are "
-        "['nccl', 'hierarchical_copy']. Supplied value: {}".format(
+        "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
             all_reduce_alg))
   cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
   return cross_device_ops_class(num_packs=num_packs)
@@ -92,9 +92,9 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   Args:
     distribution_strategy: a string specifying which distribution strategy to
-      use. Accepted values are 'off', 'one_device', 'mirrored',
-      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
-      insensitive. 'off' means not to use Distribution Strategy; 'tpu' means
-      to use TPUStrategy using `tpu_address`.
+      use. Accepted values are "off", "one_device", "mirrored",
+      "parameter_server", "multi_worker_mirrored", and "tpu" -- case
+      insensitive. "off" means not to use Distribution Strategy; "tpu" means
+      to use TPUStrategy using `tpu_address`.
     num_gpus: Number of GPUs to run this model.
     all_reduce_alg: Optional. Specifies which algorithm to use when performing
@@ -109,7 +109,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   Returns:
     tf.distribute.DistibutionStrategy object.

   Raises:
-    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
+    ValueError: if `distribution_strategy` is "off" or "one_device" and
       `num_gpus` is larger than 1; or `num_gpus` is negative or if
       `distribution_strategy` is `tpu` but `tpu_address` is not specified.
   """
@@ -121,7 +121,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
     if num_gpus > 1:
       raise ValueError(
           "When {} GPUs are specified, distribution_strategy "
-          "flag cannot be set to 'off'.".format(num_gpus))
+          "flag cannot be set to `off`.".format(num_gpus))
     return None

   if distribution_strategy == "tpu":
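For reference, a short sketch of calling the helper whose docstring is updated above. Flag values follow that docstring; the two-GPU NCCL setting is an illustrative assumption, not something this commit exercises:

from official.utils.misc import distribution_utils

# "mirrored" is matched case-insensitively; "off" would disable Distribution
# Strategy and make the helper return None.
strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy="mirrored",
    num_gpus=2,
    all_reduce_alg="nccl")  # for "mirrored": "nccl" or "hierarchical_copy"

# The benchmark mains then wrap model construction in the strategy's scope.
strategy_scope = distribution_utils.get_strategy_scope(strategy)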
@@ -157,110 +157,6 @@ def get_distribution_strategy(distribution_strategy="mirrored",
         "Unrecognized Distribution Strategy: %r" % distribution_strategy)

[This hunk deletes the `SyntheticDataset` and `SyntheticIterator` classes and
the `_monkey_patch_dataset_method`, `_undo_monkey_patch_dataset_method`,
`set_up_synthetic_data`, and `undo_set_up_synthetic_data` functions from
distribution_utils.py. The deleted block is identical, line for line, to the
code added in official/benchmark/models/synthetic_util.py above.]
 def configure_cluster(worker_hosts=None, task_index=-1):
   """Set multi-worker cluster spec in TF_CONFIG environment variable.
@@ -270,21 +166,21 @@ def configure_cluster(worker_hosts=None, task_index=-1):
   Returns:
     Number of workers in the cluster.
   """
-  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
+  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
   if tf_config:
-    num_workers = (len(tf_config['cluster'].get('chief', [])) +
-                   len(tf_config['cluster'].get('worker', [])))
+    num_workers = (len(tf_config["cluster"].get("chief", [])) +
+                   len(tf_config["cluster"].get("worker", [])))
   elif worker_hosts:
-    workers = worker_hosts.split(',')
+    workers = worker_hosts.split(",")
     num_workers = len(workers)
     if num_workers > 1 and task_index < 0:
-      raise ValueError('Must specify task_index when number of workers > 1')
+      raise ValueError("Must specify task_index when number of workers > 1")
     task_index = 0 if num_workers == 1 else task_index
-    os.environ['TF_CONFIG'] = json.dumps({
-        'cluster': {
-            'worker': workers
+    os.environ["TF_CONFIG"] = json.dumps({
+        "cluster": {
+            "worker": workers
         },
-        'task': {'type': 'worker', 'index': task_index}
+        "task": {"type": "worker", "index": task_index}
     })
   else:
     num_workers = 1
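A hypothetical two-worker call to `configure_cluster()` as reconstructed above; the host:port strings are placeholders:

import json
import os

from official.utils.misc import distribution_utils

num_workers = distribution_utils.configure_cluster(
    worker_hosts="10.0.0.1:2222,10.0.0.2:2222", task_index=0)

# num_workers == 2, and TF_CONFIG now holds the multi-worker cluster spec:
#   {"cluster": {"worker": ["10.0.0.1:2222", "10.0.0.2:2222"]},
#    "task": {"type": "worker", "index": 0}}
print(num_workers, json.loads(os.environ["TF_CONFIG"]))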
official/vision/image_classification/resnet/resnet_imagenet_main.py  (view file @ a8b5cb7a)

@@ -98,7 +98,6 @@ def run(flags_obj):
   # pylint: disable=protected-access
   if flags_obj.use_synthetic_data:
-    distribution_utils.set_up_synthetic_data()
     input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
@@ -107,7 +106,6 @@ def run(flags_obj):
         dtype=dtype,
         drop_remainder=True)
   else:
-    distribution_utils.undo_set_up_synthetic_data()
     input_fn = imagenet_preprocessing.input_fn

   # When `enable_xla` is True, we always drop the remainder of the batches