# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Downloads and extracts the binary version of the CIFAR-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

# Location of the binary CIFAR-10 tarball that main() downloads.
DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'

# Default value of --data_dir.
_DEFAULT_DATA_DIR = '/tmp/cifar10_data'

# Command-line flags; parsed in the __main__ block and exposed as FLAGS.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--data_dir',
    type=str,
    default=_DEFAULT_DATA_DIR,
    help='Directory to download data and extract the tarball',
)


Karmel Allison's avatar
Karmel Allison committed
39
def main(_):
40
41
42
43
44
45
46
47
48
  """Download and extract the tarball from Alex's website."""
  if not os.path.exists(FLAGS.data_dir):
    os.makedirs(FLAGS.data_dir)

  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(FLAGS.data_dir, filename)

  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
49
50
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
          filename, 100.0 * count * block_size / total_size))
51
52
53
54
55
56
57
58
59
60
61
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  tarfile.open(filepath, 'r:gz').extractall(FLAGS.data_dir)


if __name__ == '__main__':
  # parse_known_args() returns the flags this parser defines as FLAGS. Any
  # arguments it does not recognize are forwarded to tf.app.run, together
  # with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)