"""Downloads the UCI HIGGS Dataset and prepares train data.

The details on the dataset are in https://archive.ics.uci.edu/ml/datasets/HIGGS

It takes a while as it needs to download 2.8 GB over the network, process, then
store it into the specified location as a compressed numpy file.

Usage:
$ python data_download.py --data_dir=/tmp/higgs_data
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

# pylint: disable=g-bad-import-order
import numpy as np
import pandas as pd
from six.moves import urllib
from absl import app as absl_app
from absl import flags
import tensorflow as tf

from official.utils.flags import core as flags_core

# Base URL of the UCI machine-learning repository entry hosting the dataset.
URL_ROOT = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280"
# Gzipped CSV as published by UCI: one label column + 28 feature columns.
INPUT_FILE = "HIGGS.csv.gz"
NPZ_FILE = "HIGGS.csv.gz.npz"  # numpy compressed file to contain "data" array.


def _download_higgs_data_and_save_npz(data_dir):
  """Download the HIGGS dataset and store it as a compressed numpy file.

  Downloads the 2.8 GB gzipped CSV from the UCI repository, parses it into a
  float32 matrix (label + 28 features per row), and writes it to
  `data_dir/NPZ_FILE` as a compressed npz archive under the key "data".

  Args:
    data_dir: Directory in which to store the processed file. Created if it
      does not already exist.

  Raises:
    ValueError: If the processed data file already exists in `data_dir`.
  """
  # Join with an explicit "/": os.path.join would emit backslashes on Windows
  # and produce an invalid URL.
  input_url = "{}/{}".format(URL_ROOT, INPUT_FILE)
  np_filename = os.path.join(data_dir, NPZ_FILE)
  if tf.gfile.Exists(np_filename):
    raise ValueError("data_dir already has the processed data file: {}".format(
        np_filename))
  if not tf.gfile.Exists(data_dir):
    tf.gfile.MkDir(data_dir)
  # 2.8 GB to download.
  temp_filename = None
  try:
    tf.logging.info("Data downloading...")
    temp_filename, _ = urllib.request.urlretrieve(input_url)
    # Reading and parsing 11 million csv lines takes 2~3 minutes.
    tf.logging.info("Data processing... taking multiple minutes...")
    with gzip.open(temp_filename, "rb") as csv_file:
      # .values replaces the long-deprecated DataFrame.as_matrix(), which was
      # removed in pandas 1.0; both return the underlying numpy array.
      data = pd.read_csv(
          csv_file,
          dtype=np.float32,
          names=["c%02d" % i for i in range(29)]  # label + 28 features.
      ).values
  finally:
    # If urlretrieve raised before binding temp_filename, skip the cleanup so
    # we don't mask the original error with a NameError.
    if temp_filename:
      tf.gfile.Remove(temp_filename)

  # Writing to temporary location then copy to the data_dir (0.8 GB). The
  # context manager closes and deletes the temp file even if the copy fails.
  with tempfile.NamedTemporaryFile() as f:
    np.savez_compressed(f, data=data)
    f.flush()  # Ensure all buffered bytes are on disk before copying by name.
    tf.gfile.Copy(f.name, np_filename)
  tf.logging.info("Data saved to: {}".format(np_filename))


def main(unused_argv):
  """Script entry point: ensure the target directory exists, then fetch data."""
  target_dir = FLAGS.data_dir
  if not tf.gfile.Exists(target_dir):
    tf.gfile.MkDir(target_dir)
  _download_higgs_data_and_save_npz(target_dir)


def define_data_download_flags():
  """Register the command-line flags this download script understands."""
  help_text = flags_core.help_wrap(
      "Directory to download higgs dataset and store training/eval data.")
  flags.DEFINE_string(
      name="data_dir", default="/tmp/higgs_data", help=help_text)


if __name__ == "__main__":
  # Surface INFO-level logs so download/processing progress is visible.
  tf.logging.set_verbosity(tf.logging.INFO)
  define_data_download_flags()
  # Flag values are parsed by absl_app.run before main() is invoked; FLAGS is
  # the module-level handle main() reads.
  FLAGS = flags.FLAGS
  absl_app.run(main)