Commit eaf2bd1b authored by Shining Sun's avatar Shining Sun
Browse files

fix the dataset flag

parent 84d3c62c
...@@ -181,7 +181,7 @@ def parse_record_keras(raw_record, is_training, dtype): ...@@ -181,7 +181,7 @@ def parse_record_keras(raw_record, is_training, dtype):
Returns: Returns:
Tuple with processed image tensor and one-hot-encoded label tensor. Tuple with processed image tensor and one-hot-encoded label tensor.
""" """
if shining.dataset == IMAGENET_DATASET: if flags_obj.dataset == IMAGENET_DATASET:
image_buffer, label, bbox = imagenet_main._parse_example_proto(raw_record) image_buffer, label, bbox = imagenet_main._parse_example_proto(raw_record)
image = imagenet_preprocessing.preprocess_image( image = imagenet_preprocessing.preprocess_image(
...@@ -194,11 +194,11 @@ def parse_record_keras(raw_record, is_training, dtype): ...@@ -194,11 +194,11 @@ def parse_record_keras(raw_record, is_training, dtype):
image = tf.cast(image, dtype) image = tf.cast(image, dtype)
label = tf.sparse_to_dense(label, (imagenet_main._NUM_CLASSES,), 1) label = tf.sparse_to_dense(label, (imagenet_main._NUM_CLASSES,), 1)
elif shining.dataset == CIFAR_DATASET: elif flags_obj.dataset == CIFAR_DATASET:
image, label = cifar_main.parse_record(raw_record, is_training, dtype) image, label = cifar_main.parse_record(raw_record, is_training, dtype)
label = tf.sparse_to_dense(label, (cifar_main._NUM_CLASSES,), 1) label = tf.sparse_to_dense(label, (cifar_main._NUM_CLASSES,), 1)
else: else:
raise ValueError("Unknown dataset: {%s}".format(shining.dataset)) raise ValueError("Unknown dataset: {%s}".format(flags_obj.dataset))
return image, label return image, label
...@@ -223,7 +223,7 @@ def run_imagenet_with_keras(flags_obj): ...@@ -223,7 +223,7 @@ def run_imagenet_with_keras(flags_obj):
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
train_input_dataset, eval_input_dataset = get_data( train_input_dataset, eval_input_dataset = get_data(
shining.dataset, flags_obj.use_synthetic_data) flags_obj.dataset, flags_obj.use_synthetic_data)
# Use Keras ResNet50 applications model and native keras APIs # Use Keras ResNet50 applications model and native keras APIs
# initialize RMSprop optimizer # initialize RMSprop optimizer
...@@ -240,7 +240,7 @@ def run_imagenet_with_keras(flags_obj): ...@@ -240,7 +240,7 @@ def run_imagenet_with_keras(flags_obj):
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
num_gpus=flags_obj.num_gpus) num_gpus=flags_obj.num_gpus)
if shining.dataset == IMAGENET_DATASET: if flags_obj.dataset == IMAGENET_DATASET:
model = resnet_model_tpu.ResNet50(num_classes=imagenet_main._NUM_CLASSES) model = resnet_model_tpu.ResNet50(num_classes=imagenet_main._NUM_CLASSES)
steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size
...@@ -251,7 +251,7 @@ def run_imagenet_with_keras(flags_obj): ...@@ -251,7 +251,7 @@ def run_imagenet_with_keras(flags_obj):
num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] // num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
elif shining.dataset = CIFAR_DATASET: elif flags_obj.dataset = CIFAR_DATASET:
model = keras_resnet_model.ResNet56(input_shape=(32, 32, 3), model = keras_resnet_model.ResNet56(input_shape=(32, 32, 3),
include_top=True, include_top=True,
classes=cifar_main._NUM_CLASSES, classes=cifar_main._NUM_CLASSES,
...@@ -267,7 +267,7 @@ def run_imagenet_with_keras(flags_obj): ...@@ -267,7 +267,7 @@ def run_imagenet_with_keras(flags_obj):
num_eval_steps = (cifar_main._NUM_IMAGES['validation'] // num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
else: else:
raise ValueError("Unknown dataset: {%s}".format(shining.dataset)) raise ValueError("Unknown dataset: {%s}".format(flags_obj.dataset))
loss = 'categorical_crossentropy' loss = 'categorical_crossentropy'
accuracy = 'categorical_accuracy' accuracy = 'categorical_accuracy'
...@@ -418,9 +418,9 @@ if __name__ == '__main__': ...@@ -418,9 +418,9 @@ if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
define_keras_flags() define_keras_flags()
if shining.dataset == IMAGENET_DATASET: if flags_obj.dataset == IMAGENET_DATASET:
imagenet_main.define_imagenet_flags() imagenet_main.define_imagenet_flags()
elif shining.dataset == CIFAR_DATASET: elif flags_obj.dataset == CIFAR_DATASET:
cifar_main.define_cifar_flags() cifar_main.define_cifar_flags()
absl_app.run(main) absl_app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment