Commit 03c35ec6 authored by Shining Sun

Fixed lint and flag issues

parent 7522e4dc
@@ -52,7 +52,7 @@ class KerasCifar10BenchmarkTests(object):
   def keras_resnet56_no_dist_strat_1_gpu(self):
     """Test keras based model with Keras fit but not distribution strategies."""
     self._setup()
-    flags.FLAGS.dist_strat_off = True
+    flags.FLAGS.turn_off_distribution_strategy = True
     flags.FLAGS.num_gpus = 1
     flags.FLAGS.data_dir = DATA_DIR
     flags.FLAGS.batch_size = 128
...
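For context, a minimal sketch of how the renamed boolean flag could be defined and read with absl.flags; only the flag name comes from this diff, while the default value, help text, and the helper function are illustrative assumptions.

```python
# Minimal sketch, assuming absl.flags is used for flag definitions.
# Only the flag name matches the diff; default and help text are assumptions.
from absl import flags

flags.DEFINE_boolean(
    name='turn_off_distribution_strategy',
    default=False,
    help='When True, run without any distribution strategy.')

FLAGS = flags.FLAGS


def use_distribution_strategy():
  """Hypothetical helper: True when a distribution strategy should be built."""
  return not FLAGS.turn_off_distribution_strategy
```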
@@ -39,7 +39,7 @@ def learning_rate_schedule(current_epoch,
                            current_batch,
                            batches_per_epoch,
                            batch_size):
-  """Handles linear scaling rule, gradual warmup, and LR decay.
+  """Handles linear scaling rule and LR decay.
   Scale learning rate at epoch boundaries provided in LR_SCHEDULE by the
   provided scaling factor.
...
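The updated docstring describes a stepwise schedule: the base rate is scaled linearly with the batch size and decayed at the epoch boundaries listed in LR_SCHEDULE. A rough sketch of that behavior follows; BASE_LEARNING_RATE, the reference batch size of 128, and the LR_SCHEDULE entries are illustrative assumptions, not the repository's values.

```python
# Sketch of a stepwise schedule matching the docstring: linear scaling by
# batch size plus LR decay at the epoch boundaries in LR_SCHEDULE.
# The constants below are assumptions for illustration.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
    (1.0, 0), (0.1, 91), (0.01, 136), (0.001, 182)
]


def learning_rate_schedule(current_epoch, current_batch,
                           batches_per_epoch, batch_size):
  """Returns the learning rate for the given epoch and batch."""
  del current_batch, batches_per_epoch  # only epoch granularity is used here
  initial_lr = BASE_LEARNING_RATE * batch_size / 128  # linear scaling rule
  learning_rate = initial_lr
  for mult, start_epoch in LR_SCHEDULE:
    if current_epoch >= start_epoch:
      learning_rate = initial_lr * mult
    else:
      break
  return learning_rate
```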
@@ -180,12 +180,12 @@ def resnet56(classes=100, training=None):
   """Instantiates the ResNet56 architecture.
   Arguments:
     classes: optional number of classes to classify images into
     training: Only used if training keras model with Estimator. In other
       scenarios it is handled automatically.
   Returns:
     A Keras model instance.
   """
   # Determine proper input shape
   if backend.image_data_format() == 'channels_first':
...
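The "Determine proper input shape" step picks the input layout based on the Keras image data format. A small sketch of that pattern, assuming CIFAR-style 32x32 RGB inputs (the exact dimensions are an assumption):

```python
# Sketch of input-shape selection based on the Keras data format.
# The 32x32x3 dimensions are an assumption (CIFAR-style inputs).
from tensorflow.keras import backend, layers

if backend.image_data_format() == 'channels_first':
  input_shape = (3, 32, 32)
else:
  input_shape = (32, 32, 3)

img_input = layers.Input(shape=input_shape)
```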
@@ -32,7 +32,7 @@ def get_distribution_strategy(num_gpus,
       See tf.contrib.distribute.AllReduceCrossDeviceOps for available
       algorithms. If None, DistributionStrategy will choose based on device
       topology.
     turn_off_distribution_strategy: when set to True, do not use any
       distribution strategy. Note that when it is True, and num_gpus is
       larger than 1, it will raise a ValueError.
...
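A sketch of the check the docstring describes: with the strategy turned off, more than one GPU is rejected with a ValueError. The strategy classes returned here use the tf.distribute API and are assumptions; the original code documents tf.contrib.distribute.

```python
# Sketch of the validation described in the docstring: with the strategy
# turned off, more than one GPU is an error. The returned strategy classes
# (tf.distribute.*) are assumptions, not the repository's implementation.
import tensorflow as tf


def get_distribution_strategy(num_gpus, turn_off_distribution_strategy=False):
  if turn_off_distribution_strategy:
    if num_gpus > 1:
      raise ValueError(
          'When turn_off_distribution_strategy is True, num_gpus must be '
          '0 or 1, got %d.' % num_gpus)
    return None  # caller runs without a strategy
  if num_gpus <= 1:
    return tf.distribute.OneDeviceStrategy(
        'device:GPU:0' if num_gpus == 1 else 'device:CPU:0')
  return tf.distribute.MirroredStrategy()
```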