Commit c868da8b authored by Shining Sun's avatar Shining Sun
Browse files

bug fixes

parent 25efe03e
...@@ -136,8 +136,7 @@ def run(flags_obj): ...@@ -136,8 +136,7 @@ def run(flags_obj):
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy) flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_cifar_model.resnet56(input_shape=(32, 32, 3), model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
......
...@@ -163,7 +163,7 @@ def define_keras_flags(): ...@@ -163,7 +163,7 @@ def define_keras_flags():
flags.DEFINE_integer( flags.DEFINE_integer(
name='train_steps', default=None, name='train_steps', default=None,
help='The number of steps to run for training. If it is larger than ' help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # bathes per epoch. When this flag is ' '# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.') 'set, only one epoch is going to run for training.')
......
...@@ -33,42 +33,6 @@ BATCH_NORM_EPSILON = 1e-5 ...@@ -33,42 +33,6 @@ BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY = 2e-4 L2_WEIGHT_DECAY = 2e-4
def _obtain_input_shape(input_shape,
default_size,
data_format):
"""Internal utility to compute/validate a model's input shape.
Arguments:
input_shape: Either None (will return the default network input shape),
or a user-provided shape to be validated.
default_size: Default input width/height for the model.
data_format: Image data format to use.
Returns:
An integer shape tuple (may include None entries).
Raises:
ValueError: In case of invalid argument values.
"""
if input_shape and len(input_shape) == 3:
if data_format == 'channels_first':
if input_shape[0] not in {1, 3}:
warnings.warn(
'This model usually expects 1 or 3 input channels. '
'However, it was passed an input_shape with ' +
str(input_shape[0]) + ' input channels.')
default_shape = (input_shape[0], default_size, default_size)
else:
if input_shape[-1] not in {1, 3}:
warnings.warn(
'This model usually expects 1 or 3 input channels. '
'However, it was passed an input_shape with ' +
str(input_shape[-1]) + ' input channels.')
default_shape = (default_size, default_size, input_shape[-1])
return input_shape
def identity_building_block(input_tensor, def identity_building_block(input_tensor,
kernel_size, kernel_size,
filters, filters,
...@@ -212,7 +176,7 @@ def conv_building_block(input_tensor, ...@@ -212,7 +176,7 @@ def conv_building_block(input_tensor,
return x return x
def resnet56(input_shape=None, classes=100, training=None): def resnet56(classes=100, training=None):
"""Instantiates the ResNet56 architecture. """Instantiates the ResNet56 architecture.
Arguments: Arguments:
...@@ -225,16 +189,12 @@ def resnet56(input_shape=None, classes=100, training=None): ...@@ -225,16 +189,12 @@ def resnet56(input_shape=None, classes=100, training=None):
A Keras model instance. A Keras model instance.
""" """
# Determine proper input shape # Determine proper input shape
input_shape = _obtain_input_shape( if backend.image_data_format() == 'channels_first':
input_shape, input_shape = (3, 32, 32)
default_size=32,
data_format=tf.keras.backend.image_data_format())
img_input = tf.keras.layers.Input(shape=input_shape)
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1 bn_axis = 1
else: # channel_last
input_shape = (32, 32, 3)
bn_axis = 3
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input) x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.Conv2D(16, (3, 3), x = tf.keras.layers.Conv2D(16, (3, 3),
......
...@@ -181,6 +181,7 @@ def conv_block(input_tensor, ...@@ -181,6 +181,7 @@ def conv_block(input_tensor,
def resnet50(num_classes): def resnet50(num_classes):
# TODO(tfboyd): add training argument, just lik resnet56.
"""Instantiates the ResNet50 architecture. """Instantiates the ResNet50 architecture.
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment