Commit eaf2bd1b authored by Shining Sun's avatar Shining Sun
Browse files

fix the dataset flag

parent 84d3c62c
......@@ -181,7 +181,7 @@ def parse_record_keras(raw_record, is_training, dtype):
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
if shining.dataset == IMAGENET_DATASET:
if flags_obj.dataset == IMAGENET_DATASET:
image_buffer, label, bbox = imagenet_main._parse_example_proto(raw_record)
image = imagenet_preprocessing.preprocess_image(
......@@ -194,11 +194,11 @@ def parse_record_keras(raw_record, is_training, dtype):
image = tf.cast(image, dtype)
label = tf.sparse_to_dense(label, (imagenet_main._NUM_CLASSES,), 1)
elif shining.dataset == CIFAR_DATASET:
elif flags_obj.dataset == CIFAR_DATASET:
image, label = cifar_main.parse_record(raw_record, is_training, dtype)
label = tf.sparse_to_dense(label, (cifar_main._NUM_CLASSES,), 1)
else:
raise ValueError("Unknown dataset: {%s}".format(shining.dataset))
raise ValueError("Unknown dataset: {%s}".format(flags_obj.dataset))
return image, label
......@@ -223,7 +223,7 @@ def run_imagenet_with_keras(flags_obj):
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
train_input_dataset, eval_input_dataset = get_data(
shining.dataset, flags_obj.use_synthetic_data)
flags_obj.dataset, flags_obj.use_synthetic_data)
# Use Keras ResNet50 applications model and native keras APIs
# initialize RMSprop optimizer
......@@ -240,7 +240,7 @@ def run_imagenet_with_keras(flags_obj):
strategy = distribution_utils.get_distribution_strategy(
num_gpus=flags_obj.num_gpus)
if shining.dataset == IMAGENET_DATASET:
if flags_obj.dataset == IMAGENET_DATASET:
model = resnet_model_tpu.ResNet50(num_classes=imagenet_main._NUM_CLASSES)
steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size
......@@ -251,7 +251,7 @@ def run_imagenet_with_keras(flags_obj):
num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //
flags_obj.batch_size)
elif shining.dataset = CIFAR_DATASET:
elif flags_obj.dataset = CIFAR_DATASET:
model = keras_resnet_model.ResNet56(input_shape=(32, 32, 3),
include_top=True,
classes=cifar_main._NUM_CLASSES,
......@@ -267,7 +267,7 @@ def run_imagenet_with_keras(flags_obj):
num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
flags_obj.batch_size)
else:
raise ValueError("Unknown dataset: {%s}".format(shining.dataset))
raise ValueError("Unknown dataset: {%s}".format(flags_obj.dataset))
loss = 'categorical_crossentropy'
accuracy = 'categorical_accuracy'
......@@ -418,9 +418,9 @@ if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
define_keras_flags()
if shining.dataset == IMAGENET_DATASET:
if flags_obj.dataset == IMAGENET_DATASET:
imagenet_main.define_imagenet_flags()
elif shining.dataset == CIFAR_DATASET:
elif flags_obj.dataset == CIFAR_DATASET:
cifar_main.define_cifar_flags()
absl_app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment