Commit a8b5cb7a authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 305897677
parent c0c58423
@@ -24,6 +24,7 @@ from absl import logging
 import numpy as np
 import tensorflow as tf
 from official.benchmark.models import resnet_cifar_model
+from official.benchmark.models import synthetic_util
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
@@ -159,7 +160,7 @@ def run(flags_obj):
   strategy_scope = distribution_utils.get_strategy_scope(strategy)

   if flags_obj.use_synthetic_data:
-    distribution_utils.set_up_synthetic_data()
+    synthetic_util.set_up_synthetic_data()
     input_fn = common.get_synth_input_fn(
         height=cifar_preprocessing.HEIGHT,
         width=cifar_preprocessing.WIDTH,
@@ -168,7 +169,7 @@ def run(flags_obj):
         dtype=flags_core.get_tf_dtype(flags_obj),
         drop_remainder=True)
   else:
-    distribution_utils.undo_set_up_synthetic_data()
+    synthetic_util.undo_set_up_synthetic_data()
     input_fn = cifar_preprocessing.input_fn

   train_input_dataset = input_fn(
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to generate data directly on devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import string
from absl import logging
import tensorflow as tf


# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
  """A dataset that generates synthetic data on each device."""

  def __init__(self, dataset, split_by=1):
    # dataset.take(1) doesn't have GPU kernel.
    with tf.device('device:CPU:0'):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    initializers = []
    for t in flat_tensor:
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=rebatched_t)
      variable_data.append(v)
      initializers.append(v.initializer)
    input_data = tf.nest.pack_sequence_as(tensor, variable_data)
    self._iterator = SyntheticIterator(input_data, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))

  def __iter__(self):
    return self._iterator

  def make_one_shot_iterator(self):
    return self._iterator

  def make_initializable_iterator(self):
    return self._iterator


class SyntheticIterator(object):
  """An iterator that returns the same cached batch of synthetic data."""

  def __init__(self, input_data, initializers):
    self._input_data = input_data
    self._initializers = initializers

  def get_next(self):
    return self._input_data

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers


def _monkey_patch_dataset_method(strategy):
  """Monkey-patch `strategy`'s `make_dataset_iterator` method."""

  def make_dataset(self, dataset):
    logging.info('Using pure synthetic data.')
    with self.scope():
      if self.extended._global_batch_size:  # pylint: disable=protected-access
        return SyntheticDataset(dataset, self.num_replicas_in_sync)
      else:
        return SyntheticDataset(dataset)

  def make_iterator(self, dataset):
    dist_dataset = make_dataset(self, dataset)
    return iter(dist_dataset)

  strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
  strategy.make_dataset_iterator = make_iterator
  strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
  strategy.experimental_distribute_dataset = make_dataset


def _undo_monkey_patch_dataset_method(strategy):
  if hasattr(strategy, 'orig_make_dataset_iterator'):
    strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
  if hasattr(strategy, 'orig_distribute_dataset'):
    # Restore the original `experimental_distribute_dataset` method.
    strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset


def set_up_synthetic_data():
  _monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)


def undo_set_up_synthetic_data():
  _undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
  _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
  _undo_monkey_patch_dataset_method(
      tf.distribute.experimental.MultiWorkerMirroredStrategy)
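
For context, a minimal, hedged sketch (not part of this commit) of how the benchmark mains above exercise this module once the strategy classes are patched; the batch shapes, the dummy dataset, and the single-host MirroredStrategy are illustrative assumptions only.

import tensorflow as tf
from official.benchmark.models import synthetic_util

# One representative batch; SyntheticDataset caches it in per-device variables
# and replays it forever, so the input pipeline cannot become a bottleneck.
images = tf.zeros([128, 32, 32, 3], tf.float32)
labels = tf.zeros([128], tf.int32)
dataset = tf.data.Dataset.from_tensors((images, labels)).repeat()

synthetic_util.set_up_synthetic_data()        # monkey-patch the strategy classes
strategy = tf.distribute.MirroredStrategy()
synthetic_ds = strategy.experimental_distribute_dataset(dataset)
batch = next(iter(synthetic_ds))              # the same cached tensors every step
synthetic_util.undo_set_up_synthetic_data()   # restore the original methods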
@@ -40,7 +40,7 @@ def _collective_communication(all_reduce_alg):
     tf.distribute.experimental.CollectiveCommunication object
   Raises:
-    ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
+    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
   """
   collective_communication_options = {
       None: tf.distribute.experimental.CollectiveCommunication.AUTO,
@@ -50,7 +50,7 @@ def _collective_communication(all_reduce_alg):
   if all_reduce_alg not in collective_communication_options:
     raise ValueError(
         "When used with `multi_worker_mirrored`, valid values for "
-        "all_reduce_alg are ['ring', 'nccl']. Supplied value: {}".format(
+        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
            all_reduce_alg))
   return collective_communication_options[all_reduce_alg]
@@ -66,7 +66,7 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
     tf.distribute.CrossDeviceOps object or None.
   Raises:
-    ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
+    ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
   """
   if all_reduce_alg is None:
     return None
@@ -77,7 +77,7 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
   if all_reduce_alg not in mirrored_all_reduce_options:
     raise ValueError(
         "When used with `mirrored`, valid values for all_reduce_alg are "
-        "['nccl', 'hierarchical_copy']. Supplied value: {}".format(
+        "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
            all_reduce_alg))
   cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
   return cross_device_ops_class(num_packs=num_packs)
@@ -92,9 +92,9 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   Args:
     distribution_strategy: a string specifying which distribution strategy to
-      use. Accepted values are 'off', 'one_device', 'mirrored',
-      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case insensitive.
-      'off' means not to use Distribution Strategy; 'tpu' means to use
+      use. Accepted values are "off", "one_device", "mirrored",
+      "parameter_server", "multi_worker_mirrored", and "tpu" -- case insensitive.
+      "off" means not to use Distribution Strategy; "tpu" means to use
       TPUStrategy using `tpu_address`.
     num_gpus: Number of GPUs to run this model.
     all_reduce_alg: Optional. Specifies which algorithm to use when performing
@@ -109,7 +109,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
   Returns:
     tf.distribute.DistibutionStrategy object.
   Raises:
-    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
+    ValueError: if `distribution_strategy` is "off" or "one_device" and
       `num_gpus` is larger than 1; or `num_gpus` is negative or if
       `distribution_strategy` is `tpu` but `tpu_address` is not specified.
   """
@@ -121,7 +121,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
     if num_gpus > 1:
       raise ValueError(
           "When {} GPUs are specified, distribution_strategy "
-          "flag cannot be set to 'off'.".format(num_gpus))
+          "flag cannot be set to `off`.".format(num_gpus))
     return None

   if distribution_strategy == "tpu":
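
As a hedged aside on the values these docstring and error-message hunks describe: the sketch below uses only keyword names that appear in the docstring above (distribution_strategy, num_gpus, all_reduce_alg) and assumes they match the current signature.

from official.utils.misc import distribution_utils

# MirroredStrategy with NCCL all-reduce, assuming two local GPUs are visible.
mirrored = distribution_utils.get_distribution_strategy(
    distribution_strategy="mirrored", num_gpus=2, all_reduce_alg="nccl")

# MultiWorkerMirroredStrategy with ring collectives; "nccl" is the only other
# accepted value, and anything else raises the ValueError shown above.
multi_worker = distribution_utils.get_distribution_strategy(
    distribution_strategy="multi_worker_mirrored", all_reduce_alg="ring")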
@@ -157,110 +157,6 @@ def get_distribution_strategy(distribution_strategy="mirrored",
       "Unrecognized Distribution Strategy: %r" % distribution_strategy)

[104 lines removed here: the `SyntheticDataset` and `SyntheticIterator` classes
and the `_monkey_patch_dataset_method` / `set_up_synthetic_data` /
`undo_set_up_synthetic_data` helpers, identical to the code added in
synthetic_util.py above.]

 def configure_cluster(worker_hosts=None, task_index=-1):
   """Set multi-worker cluster spec in TF_CONFIG environment variable.
@@ -270,21 +166,21 @@ def configure_cluster(worker_hosts=None, task_index=-1):
   Returns:
     Number of workers in the cluster.
   """
-  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
+  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
   if tf_config:
-    num_workers = (len(tf_config['cluster'].get('chief', [])) +
-                   len(tf_config['cluster'].get('worker', [])))
+    num_workers = (len(tf_config["cluster"].get("chief", [])) +
+                   len(tf_config["cluster"].get("worker", [])))
   elif worker_hosts:
-    workers = worker_hosts.split(',')
+    workers = worker_hosts.split(",")
     num_workers = len(workers)
     if num_workers > 1 and task_index < 0:
-      raise ValueError('Must specify task_index when number of workers > 1')
+      raise ValueError("Must specify task_index when number of workers > 1")
     task_index = 0 if num_workers == 1 else task_index
-    os.environ['TF_CONFIG'] = json.dumps({
-        'cluster': {
-            'worker': workers
-        },
-        'task': {'type': 'worker', 'index': task_index}
-    })
+    os.environ["TF_CONFIG"] = json.dumps({
+        "cluster": {
+            "worker": workers
+        },
+        "task": {"type": "worker", "index": task_index}
+    })
   else:
     num_workers = 1
......
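
A hedged sketch of what the configure_cluster() hunk above produces for a two-worker job; the host:port strings are placeholders and TF_CONFIG is assumed to be unset beforehand.

import os
from official.utils.misc import distribution_utils

num_workers = distribution_utils.configure_cluster(
    worker_hosts="10.0.0.1:5000,10.0.0.2:5000", task_index=0)
# num_workers == 2, and TF_CONFIG now contains:
# {"cluster": {"worker": ["10.0.0.1:5000", "10.0.0.2:5000"]},
#  "task": {"type": "worker", "index": 0}}
print(num_workers, os.environ["TF_CONFIG"])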
@@ -98,7 +98,6 @@ def run(flags_obj):
   # pylint: disable=protected-access
   if flags_obj.use_synthetic_data:
-    distribution_utils.set_up_synthetic_data()
     input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
@@ -107,7 +106,6 @@ def run(flags_obj):
         dtype=dtype,
         drop_remainder=True)
   else:
-    distribution_utils.undo_set_up_synthetic_data()
     input_fn = imagenet_preprocessing.input_fn

   # When `enable_xla` is True, we always drop the remainder of the batches
......