Commit 98dd890b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

refactors some preprocessing code.

PiperOrigin-RevId: 310658964
parent 0fc994b6
......@@ -16,18 +16,21 @@
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
import pickle
import time
import timeit
# pylint: disable=wrong-import-order
from absl import logging
import numpy as np
import pandas as pd
import tensorflow as tf
import typing
from typing import Dict, Text, Tuple
# pylint: enable=wrong-import-order
from official.recommendation import constants as rconst
......@@ -35,20 +38,15 @@ from official.recommendation import data_pipeline
from official.recommendation import movielens
DATASET_TO_NUM_USERS_AND_ITEMS = {
"ml-1m": (6040, 3706),
"ml-20m": (138493, 26744)
}
_EXPECTED_CACHE_KEYS = (
rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP)
def _filter_index_sort(raw_rating_path, cache_path):
# type: (str, str, bool) -> (dict, bool)
"""Read in data CSV, and output structured data.
def read_dataframe(
raw_rating_path: Text
) -> Tuple[Dict[int, int], Dict[int, int], pd.DataFrame]:
"""Read in data CSV, and output DataFrame for downstream processing.
This function reads in the raw CSV of positive items, and performs three
preprocessing transformations:
......@@ -63,43 +61,14 @@ def _filter_index_sort(raw_rating_path, cache_path):
This allows the dataframe to be sliced by user in-place, and for the last
item to be selected simply by calling the `-1` index of a user's slice.
While all of these transformations are performed by Pandas (and are therefore
single-threaded), they only take ~2 minutes, and the overhead to apply a
MapReduce pattern to parallel process the dataset adds significant complexity
for no computational gain. For a larger dataset parallelizing this
preprocessing could yield speedups. (Also, this preprocessing step is only
performed once for an entire run.
Args:
raw_rating_path: The path to the CSV which contains the raw dataset.
cache_path: The path to the file where results of this function are saved.
Returns:
A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user
IDs to regularized user IDs, and a dict mapping raw item IDs to regularized
item IDs.
A dict mapping raw user IDs to regularized user IDs, a dict mapping raw
item IDs to regularized item IDs, and a filtered, zero-index remapped,
sorted dataframe.
"""
valid_cache = tf.io.gfile.exists(cache_path)
if valid_cache:
with tf.io.gfile.GFile(cache_path, "rb") as f:
cached_data = pickle.load(f)
# (nnigania)disabled this check as the dataset is not expected to change
# cache_age = time.time() - cached_data.get("create_time", 0)
# if cache_age > rconst.CACHE_INVALIDATION_SEC:
# valid_cache = False
for key in _EXPECTED_CACHE_KEYS:
if key not in cached_data:
valid_cache = False
if not valid_cache:
logging.info("Removing stale raw data cache file.")
tf.io.gfile.remove(cache_path)
if valid_cache:
data = cached_data
else:
with tf.io.gfile.GFile(raw_rating_path) as f:
df = pd.read_csv(f)
......@@ -142,10 +111,68 @@ def _filter_index_sort(raw_rating_path, cache_path):
# reference implementation.
df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
inplace=True, kind="mergesort")
inplace=True,
kind="mergesort")
# The dataframe does not reconstruct indices in the sort or filter steps.
df = df.reset_index()
return user_map, item_map, df.reset_index()
def _filter_index_sort(raw_rating_path: Text,
cache_path: Text) -> Tuple[pd.DataFrame, bool]:
"""Read in data CSV, and output structured data.
This function reads in the raw CSV of positive items, and performs three
preprocessing transformations:
1) Filter out all users who have not rated at least a certain number
of items. (Typically 20 items)
2) Zero index the users and items such that the largest user_id is
`num_users - 1` and the largest item_id is `num_items - 1`
3) Sort the dataframe by user_id, with timestamp as a secondary sort key.
This allows the dataframe to be sliced by user in-place, and for the last
item to be selected simply by calling the `-1` index of a user's slice.
While all of these transformations are performed by Pandas (and are therefore
single-threaded), they only take ~2 minutes, and the overhead to apply a
MapReduce pattern to parallel process the dataset adds significant complexity
for no computational gain. For a larger dataset parallelizing this
preprocessing could yield speedups. (Also, this preprocessing step is only
performed once for an entire run.
Args:
raw_rating_path: The path to the CSV which contains the raw dataset.
cache_path: The path to the file where results of this function are saved.
Returns:
A filtered, zero-index remapped, sorted dataframe, a dict mapping raw user
IDs to regularized user IDs, and a dict mapping raw item IDs to regularized
item IDs.
"""
valid_cache = tf.io.gfile.exists(cache_path)
if valid_cache:
with tf.io.gfile.GFile(cache_path, "rb") as f:
cached_data = pickle.load(f)
# (nnigania)disabled this check as the dataset is not expected to change
# cache_age = time.time() - cached_data.get("create_time", 0)
# if cache_age > rconst.CACHE_INVALIDATION_SEC:
# valid_cache = False
for key in _EXPECTED_CACHE_KEYS:
if key not in cached_data:
valid_cache = False
if not valid_cache:
logging.info("Removing stale raw data cache file.")
tf.io.gfile.remove(cache_path)
if valid_cache:
data = cached_data
else:
user_map, item_map, df = read_dataframe(raw_rating_path)
grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
......@@ -201,7 +228,7 @@ def instantiate_pipeline(dataset,
raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
user_map, item_map = raw_data["user_map"], raw_data["item_map"]
num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
if num_users != len(user_map):
raise ValueError("Expected to find {} users, but found {}".format(
......
......@@ -95,8 +95,7 @@ class BaseTest(tf.test.TestCase):
movielens.download = mock_download
movielens.NUM_RATINGS[DATASET] = NUM_PTS
data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS,
NUM_ITEMS)
movielens.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS, NUM_ITEMS)
def make_params(self, train_epochs=1):
return {
......
......@@ -84,6 +84,8 @@ NUM_RATINGS = {
ML_20M: 20000263
}
DATASET_TO_NUM_USERS_AND_ITEMS = {ML_1M: (6040, 3706), ML_20M: (138493, 26744)}
def _download_and_clean(dataset, data_dir):
"""Download MovieLens dataset in a standard format.
......@@ -284,17 +286,24 @@ def integerize_genres(dataframe):
return dataframe
def define_flags():
"""Add flags specifying data usage arguments."""
flags.DEFINE_enum(
name="dataset",
default=None,
enum_values=DATASETS,
case_sensitive=False,
help=flags_core.help_wrap("Dataset to be trained and evaluated."))
def define_data_download_flags():
"""Add flags specifying data download arguments."""
"""Add flags specifying data download and usage arguments."""
flags.DEFINE_string(
name="data_dir", default="/tmp/movielens-data/",
help=flags_core.help_wrap(
"Directory to download and extract data."))
flags.DEFINE_enum(
name="dataset", default=None,
enum_values=DATASETS, case_sensitive=False,
help=flags_core.help_wrap("Dataset to be trained and evaluated."))
define_flags()
def main(_):
......
......@@ -50,7 +50,7 @@ def get_inputs(params):
if FLAGS.use_synthetic_data:
producer = data_pipeline.DummyConstructor()
num_users, num_items = data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[
num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[
FLAGS.dataset]
num_train_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
......@@ -163,21 +163,18 @@ def define_ncf_flags():
flags.adopt_module_key_flags(flags_core)
movielens.define_flags()
flags_core.set_defaults(
model_dir="/tmp/ncf/",
data_dir="/tmp/movielens-data/",
dataset=movielens.ML_1M,
train_epochs=2,
batch_size=99000,
tpu=None
)
# Add ncf-specific flags
flags.DEFINE_enum(
name="dataset", default="ml-1m",
enum_values=["ml-1m", "ml-20m"], case_sensitive=False,
help=flags_core.help_wrap(
"Dataset to be trained and evaluated."))
flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap(
"Download data to data_dir if it is not already present."))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment