"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "57f9f9dcf94777ad0ac26c78287c10edcb1ff264"
Commit c868da8b authored by Shining Sun's avatar Shining Sun
Browse files

bug fixes

parent 25efe03e
......@@ -136,8 +136,7 @@ def run(flags_obj):
strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.turn_off_distribution_strategy)
model = resnet_cifar_model.resnet56(input_shape=(32, 32, 3),
classes=cifar_main.NUM_CLASSES)
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
......
......@@ -163,7 +163,7 @@ def define_keras_flags():
flags.DEFINE_integer(
name='train_steps', default=None,
help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # bathes per epoch. When this flag is '
'# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.')
......
......@@ -33,42 +33,6 @@ BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY = 2e-4
def _obtain_input_shape(input_shape,
default_size,
data_format):
"""Internal utility to compute/validate a model's input shape.
Arguments:
input_shape: Either None (will return the default network input shape),
or a user-provided shape to be validated.
default_size: Default input width/height for the model.
data_format: Image data format to use.
Returns:
An integer shape tuple (may include None entries).
Raises:
ValueError: In case of invalid argument values.
"""
if input_shape and len(input_shape) == 3:
if data_format == 'channels_first':
if input_shape[0] not in {1, 3}:
warnings.warn(
'This model usually expects 1 or 3 input channels. '
'However, it was passed an input_shape with ' +
str(input_shape[0]) + ' input channels.')
default_shape = (input_shape[0], default_size, default_size)
else:
if input_shape[-1] not in {1, 3}:
warnings.warn(
'This model usually expects 1 or 3 input channels. '
'However, it was passed an input_shape with ' +
str(input_shape[-1]) + ' input channels.')
default_shape = (default_size, default_size, input_shape[-1])
return input_shape
def identity_building_block(input_tensor,
kernel_size,
filters,
......@@ -212,7 +176,7 @@ def conv_building_block(input_tensor,
return x
def resnet56(input_shape=None, classes=100, training=None):
def resnet56(classes=100, training=None):
"""Instantiates the ResNet56 architecture.
Arguments:
......@@ -225,16 +189,12 @@ def resnet56(input_shape=None, classes=100, training=None):
A Keras model instance.
"""
# Determine proper input shape
input_shape = _obtain_input_shape(
input_shape,
default_size=32,
data_format=tf.keras.backend.image_data_format())
img_input = tf.keras.layers.Input(shape=input_shape)
if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
if backend.image_data_format() == 'channels_first':
input_shape = (3, 32, 32)
bn_axis = 1
else: # channel_last
input_shape = (32, 32, 3)
bn_axis = 3
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.Conv2D(16, (3, 3),
......
......@@ -181,6 +181,7 @@ def conv_block(input_tensor,
def resnet50(num_classes):
# TODO(tfboyd): add training argument, just lik resnet56.
"""Instantiates the ResNet50 architecture.
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment