"vscode:/vscode.git/clone" did not exist on "cc779bed95f2d764d4243d9d316d05bb1d749c88"
Commit c9c4981d authored by Xin Pan's avatar Xin Pan Committed by GitHub
Browse files

Merge pull request #586 from panyx0718/master

Update cifar input following data change.
parents 515bde38 6515a419
......@@ -71,12 +71,14 @@ curl -o cifar-100-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-binar
```shell
# cd to the your workspace.
# It contains an empty WORKSPACE file, resnet codes and cifar10 dataset.
# Note: User can split 5k from train set for eval set.
ls -R
.:
cifar10 resnet WORKSPACE
./cifar10:
test.bin train.bin validation.bin
data_batch_1.bin data_batch_2.bin data_batch_3.bin data_batch_4.bin
data_batch_5.bin test_batch.bin
./resnet:
BUILD cifar_input.py g3doc README.md resnet_main.py resnet_model.py
......@@ -85,7 +87,7 @@ ls -R
bazel build -c opt --config=cuda resnet/...
# Train the model.
bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \
bazel-bin/resnet/resnet_main --train_data_path=cifar10/data_batch* \
--log_root=/tmp/resnet_model \
--train_dir=/tmp/resnet_model/train \
--dataset='cifar10' \
......@@ -94,7 +96,7 @@ bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \
# Evaluate the model.
# Avoid running on the same GPU as the training job at the same time,
# otherwise, you might run out of memory.
bazel-bin/resnet/resnet_main --eval_data_path=cifar10/test.bin \
bazel-bin/resnet/resnet_main --eval_data_path=cifar10/test_batch.bin \
--log_root=/tmp/resnet_model \
--eval_dir=/tmp/resnet_model/test \
--mode=eval \
......
......@@ -49,7 +49,8 @@ def build_input(dataset, data_path, batch_size, mode):
image_bytes = image_size * image_size * depth
record_bytes = label_bytes + label_offset + image_bytes
file_queue = tf.train.string_input_producer([data_path], shuffle=True)
data_files = tf.gfile.Glob(data_path)
file_queue = tf.train.string_input_producer(data_files, shuffle=True)
# Read examples from files in the filename queue.
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
_, value = reader.read(file_queue)
......
......@@ -26,8 +26,8 @@ import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
tf.app.flags.DEFINE_string('train_data_path', '', 'Filename for training data.')
tf.app.flags.DEFINE_string('eval_data_path', '', 'Filename for eval data')
tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
tf.app.flags.DEFINE_string('train_dir', '',
'Directory to keep training outputs.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment