Commit ad09cf49 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Check whether there are GPUs instead of whether TF is built with CUDA support.

The TF pip packages are always built with CUDA support, so tf.test.is_built_with_cuda() returns True even when the user has no GPU.

PiperOrigin-RevId: 302928378
parent 7c83a9d7
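
For context (illustrative only, not part of this commit): a minimal sketch of how the two checks differ on a CPU-only machine running the CUDA-enabled TF 2.x pip package.

    import tensorflow as tf

    # Reports only how the binary was compiled; True for the CUDA-enabled pip
    # package even when no GPU is attached.
    print(tf.test.is_built_with_cuda())

    # Reports the GPUs actually visible to the runtime; an empty list on a
    # CPU-only machine, which is what the callers changed below care about.
    gpus = tf.config.list_physical_devices('GPU')
    print(bool(gpus))

    # The pattern this commit applies throughout: choose the image layout from
    # GPU presence rather than from the build configuration.
    data_format = 'channels_first' if gpus else 'channels_last'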
@@ -138,8 +138,8 @@ def run(flags_obj):
   data_format = flags_obj.data_format
   if data_format is None:
-    data_format = ('channels_first'
-                   if tf.test.is_built_with_cuda() else 'channels_last')
+    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+                   else 'channels_last')
   tf.keras.backend.set_image_data_format(data_format)
   strategy = distribution_utils.get_distribution_strategy(
...
@@ -415,7 +415,7 @@ def run_customized_training_loop(
     # Runs several steps in the host while loop.
     steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
-    if tf.test.is_built_with_cuda():
+    if tf.config.list_physical_devices('GPU'):
       # TODO(zongweiz): merge with train_steps once tf.while_loop
       # GPU performance bugs are fixed.
       for _ in range(steps):
...
@@ -183,8 +183,8 @@ def run_mnist(flags_obj):
   data_format = flags_obj.data_format
   if data_format is None:
-    data_format = ('channels_first'
-                   if tf.test.is_built_with_cuda() else 'channels_last')
+    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+                   else 'channels_last')
   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
       model_dir=flags_obj.model_dir,
...
@@ -391,8 +391,8 @@ class Model(object):
     self.resnet_size = resnet_size
     if not data_format:
-      data_format = (
-          'channels_first' if tf.test.is_built_with_cuda() else 'channels_last')
+      data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+                     else 'channels_last')
     self.resnet_version = resnet_version
     if resnet_version not in (1, 2):
...
@@ -119,8 +119,8 @@ def run(flags_obj):
   # TODO(anj-s): Set data_format without using Keras.
   data_format = flags_obj.data_format
   if data_format is None:
-    data_format = ('channels_first'
-                   if tf.test.is_built_with_cuda() else 'channels_last')
+    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+                   else 'channels_last')
   tf.keras.backend.set_image_data_format(data_format)
   strategy = distribution_utils.get_distribution_strategy(
...
@@ -71,8 +71,8 @@ def run(flags_obj):
   data_format = flags_obj.data_format
   if data_format is None:
-    data_format = ('channels_first'
-                   if tf.test.is_built_with_cuda() else 'channels_last')
+    data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+                   else 'channels_last')
   tf.keras.backend.set_image_data_format(data_format)
   # Configures cluster spec for distribution strategy.
...