Commit c9c4981d authored by Xin Pan's avatar Xin Pan Committed by GitHub
Browse files

Merge pull request #586 from panyx0718/master

Update cifar input following data change.
parents 515bde38 6515a419
...@@ -71,12 +71,14 @@ curl -o cifar-100-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-binar ...@@ -71,12 +71,14 @@ curl -o cifar-100-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-binar
```shell ```shell
# cd to the your workspace. # cd to the your workspace.
# It contains an empty WORKSPACE file, resnet codes and cifar10 dataset. # It contains an empty WORKSPACE file, resnet codes and cifar10 dataset.
# Note: User can split 5k from train set for eval set.
ls -R ls -R
.: .:
cifar10 resnet WORKSPACE cifar10 resnet WORKSPACE
./cifar10: ./cifar10:
test.bin train.bin validation.bin data_batch_1.bin data_batch_2.bin data_batch_3.bin data_batch_4.bin
data_batch_5.bin test_batch.bin
./resnet: ./resnet:
BUILD cifar_input.py g3doc README.md resnet_main.py resnet_model.py BUILD cifar_input.py g3doc README.md resnet_main.py resnet_model.py
...@@ -85,7 +87,7 @@ ls -R ...@@ -85,7 +87,7 @@ ls -R
bazel build -c opt --config=cuda resnet/... bazel build -c opt --config=cuda resnet/...
# Train the model. # Train the model.
bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \ bazel-bin/resnet/resnet_main --train_data_path=cifar10/data_batch* \
--log_root=/tmp/resnet_model \ --log_root=/tmp/resnet_model \
--train_dir=/tmp/resnet_model/train \ --train_dir=/tmp/resnet_model/train \
--dataset='cifar10' \ --dataset='cifar10' \
...@@ -94,7 +96,7 @@ bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \ ...@@ -94,7 +96,7 @@ bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \
# Evaluate the model. # Evaluate the model.
# Avoid running on the same GPU as the training job at the same time, # Avoid running on the same GPU as the training job at the same time,
# otherwise, you might run out of memory. # otherwise, you might run out of memory.
bazel-bin/resnet/resnet_main --eval_data_path=cifar10/test.bin \ bazel-bin/resnet/resnet_main --eval_data_path=cifar10/test_batch.bin \
--log_root=/tmp/resnet_model \ --log_root=/tmp/resnet_model \
--eval_dir=/tmp/resnet_model/test \ --eval_dir=/tmp/resnet_model/test \
--mode=eval \ --mode=eval \
......
...@@ -49,7 +49,8 @@ def build_input(dataset, data_path, batch_size, mode): ...@@ -49,7 +49,8 @@ def build_input(dataset, data_path, batch_size, mode):
image_bytes = image_size * image_size * depth image_bytes = image_size * image_size * depth
record_bytes = label_bytes + label_offset + image_bytes record_bytes = label_bytes + label_offset + image_bytes
file_queue = tf.train.string_input_producer([data_path], shuffle=True) data_files = tf.gfile.Glob(data_path)
file_queue = tf.train.string_input_producer(data_files, shuffle=True)
# Read examples from files in the filename queue. # Read examples from files in the filename queue.
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
_, value = reader.read(file_queue) _, value = reader.read(file_queue)
......
...@@ -26,8 +26,8 @@ import tensorflow as tf ...@@ -26,8 +26,8 @@ import tensorflow as tf
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.') tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.') tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
tf.app.flags.DEFINE_string('train_data_path', '', 'Filename for training data.') tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
tf.app.flags.DEFINE_string('eval_data_path', '', 'Filename for eval data') tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.') tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
tf.app.flags.DEFINE_string('train_dir', '', tf.app.flags.DEFINE_string('train_dir', '',
'Directory to keep training outputs.') 'Directory to keep training outputs.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment