Commit a8b5cb7a authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 305897677
parent c0c58423
@@ -24,6 +24,7 @@ from absl import logging
import numpy as np
import tensorflow as tf
from official.benchmark.models import resnet_cifar_model
from official.benchmark.models import synthetic_util
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
@@ -159,7 +160,7 @@ def run(flags_obj):
strategy_scope = distribution_utils.get_strategy_scope(strategy)
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
synthetic_util.set_up_synthetic_data()
input_fn = common.get_synth_input_fn(
height=cifar_preprocessing.HEIGHT,
width=cifar_preprocessing.WIDTH,
@@ -168,7 +169,7 @@ def run(flags_obj):
dtype=flags_core.get_tf_dtype(flags_obj),
drop_remainder=True)
else:
distribution_utils.undo_set_up_synthetic_data()
synthetic_util.undo_set_up_synthetic_data()
input_fn = cifar_preprocessing.input_fn
train_input_dataset = input_fn(
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to generate data directly on devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import string
from absl import logging
import tensorflow as tf
# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1):
    # dataset.take(1) doesn't have a GPU kernel.
with tf.device('device:CPU:0'):
tensor = tf.data.experimental.get_single_element(dataset.take(1))
flat_tensor = tf.nest.flatten(tensor)
variable_data = []
initializers = []
for t in flat_tensor:
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
v = tf.compat.v1.get_local_variable(self._random_name(),
initializer=rebatched_t)
variable_data.append(v)
initializers.append(v.initializer)
input_data = tf.nest.pack_sequence_as(tensor, variable_data)
self._iterator = SyntheticIterator(input_data, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def __iter__(self):
return self._iterator
def make_one_shot_iterator(self):
return self._iterator
def make_initializable_iterator(self):
return self._iterator
class SyntheticIterator(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, input_data, initializers):
self._input_data = input_data
self._initializers = initializers
def get_next(self):
return self._input_data
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except tf.errors.OutOfRangeError:
raise StopIteration
def initialize(self):
if tf.executing_eagerly():
return tf.no_op()
else:
return self._initializers
def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset(self, dataset):
logging.info('Using pure synthetic data.')
with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access
return SyntheticDataset(dataset, self.num_replicas_in_sync)
else:
return SyntheticDataset(dataset)
def make_iterator(self, dataset):
dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
    strategy.experimental_distribute_dataset = strategy.orig_distribute_dataset
def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
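The module above is consumed exactly as in the CIFAR-10 benchmark diff earlier: patch the strategy classes before distributing the dataset, and undo the patch when real data is wanted. A minimal usage sketch, assuming only the functions defined above (the `input_fn` and flag wiring are placeholders, not part of this commit):

```python
import tensorflow as tf
from official.benchmark.models import synthetic_util


def get_distributed_dataset(strategy, input_fn, use_synthetic_data):
  """Returns a (possibly synthetic) distributed dataset for `strategy`."""
  if use_synthetic_data:
    # Patches OneDeviceStrategy, MirroredStrategy and
    # MultiWorkerMirroredStrategy so experimental_distribute_dataset() returns
    # a SyntheticDataset that replays one cached batch from local variables.
    synthetic_util.set_up_synthetic_data()
  else:
    synthetic_util.undo_set_up_synthetic_data()
  dataset = input_fn()  # an ordinary tf.data.Dataset
  return strategy.experimental_distribute_dataset(dataset)
```

Because the patch replaces methods on the strategy classes themselves, it also affects strategy instances created before `set_up_synthetic_data()` is called.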
@@ -40,7 +40,7 @@ def _collective_communication(all_reduce_alg):
tf.distribute.experimental.CollectiveCommunication object
Raises:
ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
"""
collective_communication_options = {
None: tf.distribute.experimental.CollectiveCommunication.AUTO,
@@ -50,7 +50,7 @@ def _collective_communication(all_reduce_alg):
if all_reduce_alg not in collective_communication_options:
raise ValueError(
"When used with `multi_worker_mirrored`, valid values for "
"all_reduce_alg are ['ring', 'nccl']. Supplied value: {}".format(
"all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
all_reduce_alg))
return collective_communication_options[all_reduce_alg]
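For context, the enum this helper selects is passed straight to the multi-worker strategy constructor. A hedged sketch of the call site (assuming the TF 2.1-era `communication` constructor argument; the helper itself is module-private, so the mapping is shown inline):

```python
import tensorflow as tf

# Mirrors the mapping above: None -> AUTO, "ring" -> RING, "nccl" -> NCCL.
communication = tf.distribute.experimental.CollectiveCommunication.RING
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=communication)
```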
@@ -66,7 +66,7 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
tf.distribute.CrossDeviceOps object or None.
Raises:
ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
"""
if all_reduce_alg is None:
return None
@@ -77,7 +77,7 @@ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
if all_reduce_alg not in mirrored_all_reduce_options:
raise ValueError(
"When used with `mirrored`, valid values for all_reduce_alg are "
"['nccl', 'hierarchical_copy']. Supplied value: {}".format(
"[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
all_reduce_alg))
cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
return cross_device_ops_class(num_packs=num_packs)
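Likewise, the `tf.distribute.CrossDeviceOps` object built here is handed to `MirroredStrategy`. A short illustrative sketch of the equivalent direct construction (the device list and `num_packs` value are examples only):

```python
import tensorflow as tf

# "nccl" maps to NcclAllReduce and "hierarchical_copy" to
# HierarchicalCopyAllReduce; num_packs controls how gradients are packed
# before the all-reduce.
cross_device_ops = tf.distribute.NcclAllReduce(num_packs=1)
strategy = tf.distribute.MirroredStrategy(
    devices=["/gpu:0", "/gpu:1"], cross_device_ops=cross_device_ops)
```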
@@ -92,9 +92,9 @@ def get_distribution_strategy(distribution_strategy="mirrored",
Args:
distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are 'off', 'one_device', 'mirrored',
'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case insensitive.
'off' means not to use Distribution Strategy; 'tpu' means to use
use. Accepted values are "off", "one_device", "mirrored",
"parameter_server", "multi_worker_mirrored", and "tpu" -- case insensitive.
"off" means not to use Distribution Strategy; "tpu" means to use
TPUStrategy using `tpu_address`.
num_gpus: Number of GPUs to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
@@ -109,7 +109,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
Returns:
    tf.distribute.DistributionStrategy object.
Raises:
ValueError: if `distribution_strategy` is 'off' or 'one_device' and
ValueError: if `distribution_strategy` is "off" or "one_device" and
`num_gpus` is larger than 1; or `num_gpus` is negative or if
`distribution_strategy` is `tpu` but `tpu_address` is not specified.
"""
@@ -121,7 +121,7 @@ def get_distribution_strategy(distribution_strategy="mirrored",
if num_gpus > 1:
raise ValueError(
"When {} GPUs are specified, distribution_strategy "
"flag cannot be set to 'off'.".format(num_gpus))
"flag cannot be set to `off`.".format(num_gpus))
return None
if distribution_strategy == "tpu":
@@ -157,110 +157,6 @@ def get_distribution_strategy(distribution_strategy="mirrored",
"Unrecognized Distribution Strategy: %r" % distribution_strategy)
# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
# later.
class SyntheticDataset(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, dataset, split_by=1):
# dataset.take(1) doesn't have GPU kernel.
with tf.device('device:CPU:0'):
tensor = tf.data.experimental.get_single_element(dataset.take(1))
flat_tensor = tf.nest.flatten(tensor)
variable_data = []
initializers = []
for t in flat_tensor:
rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
v = tf.compat.v1.get_local_variable(self._random_name(),
initializer=rebatched_t)
variable_data.append(v)
initializers.append(v.initializer)
input_data = tf.nest.pack_sequence_as(tensor, variable_data)
self._iterator = SyntheticIterator(input_data, initializers)
def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def __iter__(self):
return self._iterator
def make_one_shot_iterator(self):
return self._iterator
def make_initializable_iterator(self):
return self._iterator
class SyntheticIterator(object):
"""A dataset that generates synthetic data on each device."""
def __init__(self, input_data, initializers):
self._input_data = input_data
self._initializers = initializers
def get_next(self):
return self._input_data
def next(self):
return self.__next__()
def __next__(self):
try:
return self.get_next()
except tf.errors.OutOfRangeError:
raise StopIteration
def initialize(self):
if tf.executing_eagerly():
return tf.no_op()
else:
return self._initializers
def _monkey_patch_dataset_method(strategy):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def make_dataset(self, dataset):
logging.info('Using pure synthetic data.')
with self.scope():
if self.extended._global_batch_size: # pylint: disable=protected-access
return SyntheticDataset(dataset, self.num_replicas_in_sync)
else:
return SyntheticDataset(dataset)
def make_iterator(self, dataset):
dist_dataset = make_dataset(self, dataset)
return iter(dist_dataset)
strategy.orig_make_dataset_iterator = strategy.make_dataset_iterator
strategy.make_dataset_iterator = make_iterator
strategy.orig_distribute_dataset = strategy.experimental_distribute_dataset
strategy.experimental_distribute_dataset = make_dataset
def _undo_monkey_patch_dataset_method(strategy):
if hasattr(strategy, 'orig_make_dataset_iterator'):
strategy.make_dataset_iterator = strategy.orig_make_dataset_iterator
if hasattr(strategy, 'orig_distribute_dataset'):
strategy.make_dataset_iterator = strategy.orig_distribute_dataset
def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
def configure_cluster(worker_hosts=None, task_index=-1):
"""Set multi-worker cluster spec in TF_CONFIG environment variable.
@@ -270,21 +166,21 @@ def configure_cluster(worker_hosts=None, task_index=-1):
Returns:
Number of workers in the cluster.
"""
tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if tf_config:
num_workers = (len(tf_config['cluster'].get('chief', [])) +
len(tf_config['cluster'].get('worker', [])))
num_workers = (len(tf_config["cluster"].get("chief", [])) +
len(tf_config["cluster"].get("worker", [])))
elif worker_hosts:
workers = worker_hosts.split(',')
workers = worker_hosts.split(",")
num_workers = len(workers)
if num_workers > 1 and task_index < 0:
raise ValueError('Must specify task_index when number of workers > 1')
raise ValueError("Must specify task_index when number of workers > 1")
task_index = 0 if num_workers == 1 else task_index
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': workers
os.environ["TF_CONFIG"] = json.dumps({
"cluster": {
"worker": workers
},
'task': {'type': 'worker', 'index': task_index}
"task": {"type": "worker", "index": task_index}
})
else:
num_workers = 1
......
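The net effect of `configure_cluster` is to materialize a TF_CONFIG that the collective ops read at startup. A sketch of a two-worker call and the environment it produces (hostnames are illustrative; each host runs the same code with its own `task_index`):

```python
import json
import os

from official.utils.misc import distribution_utils

num_workers = distribution_utils.configure_cluster(
    worker_hosts="10.0.0.1:5000,10.0.0.2:5000", task_index=0)

print(num_workers)  # 2
print(json.loads(os.environ["TF_CONFIG"]))
# {'cluster': {'worker': ['10.0.0.1:5000', '10.0.0.2:5000']},
#  'task': {'type': 'worker', 'index': 0}}
```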
@@ -98,7 +98,6 @@ def run(flags_obj):
# pylint: disable=protected-access
if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data()
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
@@ -107,7 +106,6 @@ def run(flags_obj):
dtype=dtype,
drop_remainder=True)
else:
distribution_utils.undo_set_up_synthetic_data()
input_fn = imagenet_preprocessing.input_fn
# When `enable_xla` is True, we always drop the remainder of the batches
......