Commit ad09cf49 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Check if there are GPUs instead of if TF is built with CUDA support.

The TF pip packages are always built with CUDA support, so tf.test.is_built_with_cuda() would return True even if the user had no GPU.

PiperOrigin-RevId: 302928378
parent 7c83a9d7
......@@ -138,8 +138,8 @@ def run(flags_obj):
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
......
......@@ -415,7 +415,7 @@ def run_customized_training_loop(
# Runs several steps in the host while loop.
steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
if tf.test.is_built_with_cuda():
if tf.config.list_physical_devices('GPU'):
# TODO(zongweiz): merge with train_steps once tf.while_loop
# GPU performance bugs are fixed.
for _ in range(steps):
......
......@@ -183,8 +183,8 @@ def run_mnist(flags_obj):
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
else 'channels_last')
mnist_classifier = tf.estimator.Estimator(
model_fn=model_function,
model_dir=flags_obj.model_dir,
......
......@@ -391,8 +391,8 @@ class Model(object):
self.resnet_size = resnet_size
if not data_format:
data_format = (
'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
else 'channels_last')
self.resnet_version = resnet_version
if resnet_version not in (1, 2):
......
......@@ -119,8 +119,8 @@ def run(flags_obj):
# TODO(anj-s): Set data_format without using Keras.
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy(
......
......@@ -71,8 +71,8 @@ def run(flags_obj):
data_format = flags_obj.data_format
if data_format is None:
data_format = ('channels_first'
if tf.test.is_built_with_cuda() else 'channels_last')
data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
else 'channels_last')
tf.keras.backend.set_image_data_format(data_format)
# Configures cluster spec for distribution strategy.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment