Unverified Commit 1f3247f4 authored by Ayushman Kumar's avatar Ayushman Kumar Committed by GitHub
Browse files

Merge pull request #6 from tensorflow/master

Updated
parents 370a4c8d 0265f59c
This diff is collapsed.
...@@ -119,8 +119,8 @@ def run(flags_obj): ...@@ -119,8 +119,8 @@ def run(flags_obj):
# TODO(anj-s): Set data_format without using Keras. # TODO(anj-s): Set data_format without using Keras.
data_format = flags_obj.data_format data_format = flags_obj.data_format
if data_format is None: if data_format is None:
data_format = ('channels_first' data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
if tf.test.is_built_with_cuda() else 'channels_last') else 'channels_last')
tf.keras.backend.set_image_data_format(data_format) tf.keras.backend.set_image_data_format(data_format)
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
......
...@@ -71,8 +71,8 @@ def run(flags_obj): ...@@ -71,8 +71,8 @@ def run(flags_obj):
data_format = flags_obj.data_format data_format = flags_obj.data_format
if data_format is None: if data_format is None:
data_format = ('channels_first' data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
if tf.test.is_built_with_cuda() else 'channels_last') else 'channels_last')
tf.keras.backend.set_image_data_format(data_format) tf.keras.backend.set_image_data_format(data_format)
# Configures cluster spec for distribution strategy. # Configures cluster spec for distribution strategy.
......
...@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable, ...@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
loss = tf.reduce_sum(prediction_loss) * (1.0 / loss = tf.reduce_sum(prediction_loss) * (1.0 /
self.flags_obj.batch_size) self.flags_obj.batch_size)
num_replicas = self.strategy.num_replicas_in_sync num_replicas = self.strategy.num_replicas_in_sync
l2_weight_decay = 1e-4
if self.flags_obj.single_l2_loss_op: if self.flags_obj.single_l2_loss_op:
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([ l2_loss = l2_weight_decay * 2 * tf.add_n([
tf.nn.l2_loss(v) tf.nn.l2_loss(v)
for v in self.model.trainable_variables for v in self.model.trainable_variables
if 'bn' not in v.name if 'bn' not in v.name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment