# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocess dataset and construct any necessary artifacts."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import contextlib
import gc
import hashlib
import json
import os
import pickle
import signal
import socket
import subprocess
import threading
import time
import timeit
import typing

# pylint: disable=wrong-import-order
from absl import app as absl_app
from absl import flags
import numpy as np
import pandas as pd
import six
import tensorflow as tf
# pylint: enable=wrong-import-order

from official.datasets import movielens
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import stat_utils
from official.utils.logs import mlperf_helper


DATASET_TO_NUM_USERS_AND_ITEMS = {
    "ml-1m": (6040, 3706),
    "ml-20m": (138493, 26744)
}


_EXPECTED_CACHE_KEYS = (
    rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
    rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP, "match_mlperf")


def _filter_index_sort(raw_rating_path, cache_path, match_mlperf):
  # type: (str, str, bool) -> (dict, bool)
  """Read in data CSV, and output structured data.

  This function reads in the raw CSV of positive items, and performs three
  preprocessing transformations:

  1)  Filter out all users who have not rated at least a certain number
      of items (typically 20 items).

  2)  Zero-index the users and items such that the largest user_id is
      `num_users - 1` and the largest item_id is `num_items - 1`.

  3)  Sort the dataframe by user_id, with timestamp as a secondary sort key.
      This allows the dataframe to be sliced by user in-place, and the last
      item for a user to be selected simply by taking the `-1` index of that
      user's slice.

  While all of these transformations are performed by Pandas (and are therefore
  single-threaded), they only take ~2 minutes, and the overhead of applying a
  MapReduce pattern to process the dataset in parallel adds significant
  complexity for no computational gain. For a larger dataset, parallelizing
  this preprocessing could yield speedups. (Also, this preprocessing step is
  only performed once for an entire run.)

  Args:
    raw_rating_path: The path to the CSV which contains the raw dataset.
    cache_path: The path to the file where results of this function are saved.
    match_mlperf: If True, change the sorting algorithm to match the MLPerf
      reference implementation.

  Returns:
    A dict containing the filtered, zero-index remapped, and sorted training
    and evaluation data, the dict mapping raw user IDs to regularized user
    IDs, and the dict mapping raw item IDs to regularized item IDs, along
    with a bool indicating whether a valid cache file was used.
  """
  valid_cache = tf.gfile.Exists(cache_path)
  if valid_cache:
    with tf.gfile.Open(cache_path, "rb") as f:
      cached_data = pickle.load(f)

    cache_age = time.time() - cached_data.get("create_time", 0)
    if cache_age > rconst.CACHE_INVALIDATION_SEC:
      valid_cache = False

    # .get() avoids a KeyError for cache files written before this key existed.
    if cached_data.get("match_mlperf") != match_mlperf:
      valid_cache = False

    for key in _EXPECTED_CACHE_KEYS:
      if key not in cached_data:
        valid_cache = False

    if not valid_cache:
      tf.logging.info("Removing stale raw data cache file.")
      tf.gfile.Remove(cache_path)

  if valid_cache:
    data = cached_data
  else:
    with tf.gfile.Open(raw_rating_path) as f:
      df = pd.read_csv(f)

    # Keep only users who have rated at least rconst.MIN_NUM_RATINGS (20) items.
    grouped = df.groupby(movielens.USER_COLUMN)
    df = grouped.filter(
        lambda x: len(x) >= rconst.MIN_NUM_RATINGS)  # type: pd.DataFrame

    original_users = df[movielens.USER_COLUMN].unique()
    original_items = df[movielens.ITEM_COLUMN].unique()

    # Map user and item IDs to zero-based indices for the following processing.
    tf.logging.info("Generating user_map and item_map...")
    user_map = {user: index for index, user in enumerate(original_users)}
    item_map = {item: index for index, item in enumerate(original_items)}

    df[movielens.USER_COLUMN] = df[movielens.USER_COLUMN].apply(
        lambda user: user_map[user])
    df[movielens.ITEM_COLUMN] = df[movielens.ITEM_COLUMN].apply(
        lambda item: item_map[item])
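
    # Illustrative sketch of the remap above, using assumed toy values that
    # are not part of the pipeline: if the raw user column is [7, 3, 7, 12],
    # then original_users == array([7, 3, 12]), user_map == {7: 0, 3: 1,
    # 12: 2}, and the remapped column becomes [0, 1, 0, 2]. Item IDs are
    # remapped the same way.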

    num_users = len(original_users)
    num_items = len(original_items)

    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.PREPROC_HP_NUM_EVAL,
                            value=rconst.NUM_EVAL_NEGATIVES)
    mlperf_helper.ncf_print(
        key=mlperf_helper.TAGS.PREPROC_HP_SAMPLE_EVAL_REPLACEMENT,
        value=match_mlperf)

    assert num_users <= np.iinfo(rconst.USER_DTYPE).max
    assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
    assert df[movielens.USER_COLUMN].max() == num_users - 1
    assert df[movielens.ITEM_COLUMN].max() == num_items - 1

    # This sort is used to shard the dataframe by user, and later to select
    # the last item for a user to be used in validation.
    tf.logging.info("Sorting by user, timestamp...")

    # This sort is equivalent to
    #   df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
    #   inplace=True)
    # except that the order of items with the same user and timestamp are
    # sometimes different. For some reason, this sort results in a better
    # hit-rate during evaluation, matching the performance of the MLPerf
    # reference implementation.
    df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
    df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
                   inplace=True, kind="mergesort")

    df = df.reset_index()  # The dataframe does not reconstruct indices in the
                           # sort or filter steps.

    grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
    eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
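
    # Illustrative sketch with assumed toy values (comment only, not part of
    # the pipeline): for a user whose timestamp-sorted items are
    # [i_0, ..., i_k], `grouped.tail(1)` selects i_k (the most recent rating)
    # as the single evaluation example, while the remaining rows
    # [i_0, ..., i_{k-1}] become that user's training examples.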

    data = {
        rconst.TRAIN_USER_KEY: train_df[movielens.USER_COLUMN]
                               .values.astype(rconst.USER_DTYPE),
        rconst.TRAIN_ITEM_KEY: train_df[movielens.ITEM_COLUMN]
                               .values.astype(rconst.ITEM_DTYPE),
        rconst.EVAL_USER_KEY: eval_df[movielens.USER_COLUMN]
                              .values.astype(rconst.USER_DTYPE),
        rconst.EVAL_ITEM_KEY: eval_df[movielens.ITEM_COLUMN]
                              .values.astype(rconst.ITEM_DTYPE),
        rconst.USER_MAP: user_map,
        rconst.ITEM_MAP: item_map,
        "create_time": time.time(),
        "match_mlperf": match_mlperf,
    }

    tf.logging.info("Writing raw data cache.")
    with tf.gfile.Open(cache_path, "wb") as f:
      pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

  # TODO(robieta): MLPerf cache clear.
  return data, valid_cache


def instantiate_pipeline(dataset, data_dir, params, constructor_type=None,
                         deterministic=False):
  # type: (str, str, dict, typing.Optional[str], bool) -> (int, int, data_pipeline.BaseDataConstructor)
  """Load and digest data CSV into a usable form.

  Args:
    dataset: The name of the dataset to be used.
    data_dir: The root directory of the dataset.
    params: dict of parameters for the run.
    constructor_type: The name of the constructor subclass that should be used
      for the input pipeline.
    deterministic: Tell the data constructor to produce data deterministically.
  """
  tf.logging.info("Beginning data preprocessing.")

  st = timeit.default_timer()
  raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
  cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)

  raw_data, _ = _filter_index_sort(raw_rating_path, cache_path,
                                   params["match_mlperf"])
  user_map, item_map = raw_data["user_map"], raw_data["item_map"]
  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]

  if num_users != len(user_map):
    raise ValueError("Expected to find {} users, but found {}".format(
        num_users, len(user_map)))
  if num_items != len(item_map):
    raise ValueError("Expected to find {} items, but found {}".format(
        num_items, len(item_map)))

  producer = data_pipeline.get_constructor(constructor_type or "materialized")(
      maximum_number_epochs=params["train_epochs"],
      num_users=num_users,
      num_items=num_items,
      user_map=user_map,
      item_map=item_map,
      train_pos_users=raw_data[rconst.TRAIN_USER_KEY],
      train_pos_items=raw_data[rconst.TRAIN_ITEM_KEY],
      train_batch_size=params["batch_size"],
      batches_per_train_step=params["batches_per_step"],
      num_train_negatives=params["num_neg"],
      eval_pos_users=raw_data[rconst.EVAL_USER_KEY],
      eval_pos_items=raw_data[rconst.EVAL_ITEM_KEY],
      eval_batch_size=params["eval_batch_size"],
      batches_per_eval_step=params["batches_per_step"],
      stream_files=params["use_tpu"],
      deterministic=deterministic
  )

  run_time = timeit.default_timer() - st
  tf.logging.info("Data preprocessing complete. Time: {:.1f} sec."
                  .format(run_time))

  print(producer)
  return num_users, num_items, producer
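

# Illustrative usage sketch, kept as a comment so the module itself is
# unchanged. It assumes the MovieLens "ml-1m" ratings file already exists
# under the hypothetical directory "/tmp/movielens" and that `params` carries
# exactly the keys consumed above:
#
#   num_users, num_items, producer = instantiate_pipeline(
#       dataset="ml-1m",
#       data_dir="/tmp/movielens",
#       params={"match_mlperf": False, "train_epochs": 1, "batch_size": 256,
#               "batches_per_step": 1, "num_neg": 4, "eval_batch_size": 1024,
#               "use_tpu": False},
#       constructor_type="materialized",
#       deterministic=False)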