Commit 98dd890b authored by A. Unique TensorFlower

Refactors some preprocessing code: extracts DataFrame loading into a reusable read_dataframe() helper, moves DATASET_TO_NUM_USERS_AND_ITEMS from data_preprocessing into movielens, and consolidates the dataset flag into movielens.define_flags().

PiperOrigin-RevId: 310658964
parent 0fc994b6
@@ -16,18 +16,21 @@
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
import pickle
import time
import timeit
# pylint: disable=wrong-import-order
from absl import logging
import numpy as np
import pandas as pd
import tensorflow as tf
import typing
from typing import Dict, Text, Tuple
# pylint: enable=wrong-import-order
from official.recommendation import constants as rconst
@@ -35,19 +38,88 @@ from official.recommendation import data_pipeline
from official.recommendation import movielens
DATASET_TO_NUM_USERS_AND_ITEMS = {
"ml-1m": (6040, 3706),
"ml-20m": (138493, 26744)
}
_EXPECTED_CACHE_KEYS = (
rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP)
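# A minimal sketch, not part of this commit, of how a pickle cache keyed by
# _EXPECTED_CACHE_KEYS might be validated before reuse. The helper name is
# hypothetical; the real logic lives in the elided body of
# _filter_index_sort below.
def _load_valid_cache(cache_path):
  if not tf.io.gfile.exists(cache_path):
    return None
  with tf.io.gfile.GFile(cache_path, "rb") as f:
    cached_data = pickle.load(f)
  # A cache is usable only if every expected key survived serialization.
  if not all(key in cached_data for key in _EXPECTED_CACHE_KEYS):
    return None
  return cached_data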
def _filter_index_sort(raw_rating_path, cache_path):
# type: (str, str) -> (dict, bool)
def read_dataframe(
raw_rating_path: Text
) -> Tuple[Dict[int, int], Dict[int, int], pd.DataFrame]:
"""Read in data CSV, and output DataFrame for downstream processing.
This function reads in the raw CSV of positive items, and performs three
preprocessing transformations:
1) Filter out all users who have rated fewer than a minimum number of
items. (Typically 20 items.)
2) Zero-index the users and items so that the largest user_id is
`num_users - 1` and the largest item_id is `num_items - 1`.
3) Sort the dataframe by user_id, with timestamp as a secondary sort key.
This allows the dataframe to be sliced by user in place, and the last
item of each user to be selected simply by taking the `-1` index of that
user's slice.
Args:
raw_rating_path: The path to the CSV which contains the raw dataset.
Returns:
A dict mapping raw user IDs to regularized user IDs, a dict mapping raw
item IDs to regularized item IDs, and a filtered, zero-index remapped,
sorted dataframe.
"""
with tf.io.gfile.GFile(raw_rating_path) as f:
df = pd.read_csv(f)
# Keep only users who have rated at least rconst.MIN_NUM_RATINGS (20) items.
grouped = df.groupby(movielens.USER_COLUMN)
df = grouped.filter(
lambda x: len(x) >= rconst.MIN_NUM_RATINGS) # type: pd.DataFrame
original_users = df[movielens.USER_COLUMN].unique()
original_items = df[movielens.ITEM_COLUMN].unique()
# Map user and item IDs to zero-based indices for downstream processing.
logging.info("Generating user_map and item_map...")
user_map = {user: index for index, user in enumerate(original_users)}
item_map = {item: index for index, item in enumerate(original_items)}
df[movielens.USER_COLUMN] = df[movielens.USER_COLUMN].apply(
lambda user: user_map[user])
df[movielens.ITEM_COLUMN] = df[movielens.ITEM_COLUMN].apply(
lambda item: item_map[item])
num_users = len(original_users)
num_items = len(original_items)
assert num_users <= np.iinfo(rconst.USER_DTYPE).max
assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
assert df[movielens.USER_COLUMN].max() == num_users - 1
assert df[movielens.ITEM_COLUMN].max() == num_items - 1
# This sort is used to shard the dataframe by user, and later to select
# the last item for a user to be used in validation.
logging.info("Sorting by user, timestamp...")
# This sort is equivalent to
# df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
# inplace=True)
# except that the order of items with the same user and timestamp is
# sometimes different. For some reason, this sort results in a better
# hit-rate during evaluation, matching the performance of the MLPerf
# reference implementation.
df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
inplace=True,
kind="mergesort")
# Sorting and filtering do not rebuild the dataframe's index, so reset it.
return user_map, item_map, df.reset_index()
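# A concrete illustration of the two-pass sort above: a self-contained toy
# (invented data; literal column names instead of the movielens constants).
# The toy timestamps are unique, so the outcome is deterministic; with ties,
# the stable mergesort second pass is what keeps the layout reproducible.
import pandas as pd

toy = pd.DataFrame({
    "user_id": [1, 0, 0, 1, 0],
    "item_id": [7, 3, 5, 9, 2],
    "timestamp": [40, 10, 30, 20, 25],
})
toy.sort_values(by="timestamp", inplace=True)
toy.sort_values(["user_id", "timestamp"], inplace=True, kind="mergesort")
toy = toy.reset_index(drop=True)
# Each user's rows are now contiguous and time-ordered, so the most recent
# rating is the last row of that user's slice:
assert list(toy.groupby("user_id").tail(1)["item_id"]) == [5, 7]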
def _filter_index_sort(raw_rating_path: Text,
cache_path: Text) -> Tuple[pd.DataFrame, bool]:
"""Read in data CSV, and output structured data.
This function reads in the raw CSV of positive items, and performs three
@@ -100,52 +172,7 @@ def _filter_index_sort(raw_rating_path, cache_path):
if valid_cache:
data = cached_data
else:
with tf.io.gfile.GFile(raw_rating_path) as f:
df = pd.read_csv(f)
# Keep only users who have rated at least rconst.MIN_NUM_RATINGS (20) items.
grouped = df.groupby(movielens.USER_COLUMN)
df = grouped.filter(
lambda x: len(x) >= rconst.MIN_NUM_RATINGS) # type: pd.DataFrame
original_users = df[movielens.USER_COLUMN].unique()
original_items = df[movielens.ITEM_COLUMN].unique()
# Map user and item IDs to zero-based indices for downstream processing.
logging.info("Generating user_map and item_map...")
user_map = {user: index for index, user in enumerate(original_users)}
item_map = {item: index for index, item in enumerate(original_items)}
df[movielens.USER_COLUMN] = df[movielens.USER_COLUMN].apply(
lambda user: user_map[user])
df[movielens.ITEM_COLUMN] = df[movielens.ITEM_COLUMN].apply(
lambda item: item_map[item])
num_users = len(original_users)
num_items = len(original_items)
assert num_users <= np.iinfo(rconst.USER_DTYPE).max
assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
assert df[movielens.USER_COLUMN].max() == num_users - 1
assert df[movielens.ITEM_COLUMN].max() == num_items - 1
# This sort is used to shard the dataframe by user, and later to select
# the last item for a user to be used in validation.
logging.info("Sorting by user, timestamp...")
# This sort is equivalent to
# df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
# inplace=True)
# except that the order of items with the same user and timestamp is
# sometimes different. For some reason, this sort results in a better
# hit-rate during evaluation, matching the performance of the MLPerf
# reference implementation.
df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
inplace=True, kind="mergesort")
# Sorting and filtering do not rebuild the dataframe's index, so reset it.
df = df.reset_index()
user_map, item_map, df = read_dataframe(raw_rating_path)
grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])
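# Two sanity checks (not in this commit; both invariants follow directly
# from tail(1) and iloc[:-1]) that could be asserted after the split:
# every qualifying user contributes exactly one eval row, and train plus
# eval together recover the full dataframe.
assert len(eval_df) == df[movielens.USER_COLUMN].nunique()
assert len(train_df) + len(eval_df) == len(df)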
@@ -201,7 +228,7 @@ def instantiate_pipeline(dataset,
raw_data, _ = _filter_index_sort(raw_rating_path, cache_path)
user_map, item_map = raw_data["user_map"], raw_data["item_map"]
num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
if num_users != len(user_map):
raise ValueError("Expected to find {} users, but found {}".format(
......
@@ -95,8 +95,7 @@ class BaseTest(tf.test.TestCase):
movielens.download = mock_download
movielens.NUM_RATINGS[DATASET] = NUM_PTS
data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS,
NUM_ITEMS)
movielens.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS, NUM_ITEMS)
def make_params(self, train_epochs=1):
return {
......
@@ -84,6 +84,8 @@ NUM_RATINGS = {
ML_20M: 20000263
}
DATASET_TO_NUM_USERS_AND_ITEMS = {ML_1M: (6040, 3706), ML_20M: (138493, 26744)}
def _download_and_clean(dataset, data_dir):
"""Download MovieLens dataset in a standard format.
@@ -284,17 +286,24 @@ def integerize_genres(dataframe):
return dataframe
def define_flags():
"""Add flags specifying data usage arguments."""
flags.DEFINE_enum(
name="dataset",
default=None,
enum_values=DATASETS,
case_sensitive=False,
help=flags_core.help_wrap("Dataset to be trained and evaluated."))
def define_data_download_flags():
"""Add flags specifying data download arguments."""
"""Add flags specifying data download and usage arguments."""
flags.DEFINE_string(
name="data_dir", default="/tmp/movielens-data/",
help=flags_core.help_wrap(
"Directory to download and extract data."))
flags.DEFINE_enum(
name="dataset", default=None,
enum_values=DATASETS, case_sensitive=False,
help=flags_core.help_wrap("Dataset to be trained and evaluated."))
define_flags()
def main(_):
......
@@ -50,7 +50,7 @@ def get_inputs(params):
if FLAGS.use_synthetic_data:
producer = data_pipeline.DummyConstructor()
num_users, num_items = data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[
num_users, num_items = movielens.DATASET_TO_NUM_USERS_AND_ITEMS[
FLAGS.dataset]
num_train_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
num_eval_steps = rconst.SYNTHETIC_BATCHES_PER_EPOCH
@@ -163,21 +163,18 @@ def define_ncf_flags():
flags.adopt_module_key_flags(flags_core)
movielens.define_flags()
flags_core.set_defaults(
model_dir="/tmp/ncf/",
data_dir="/tmp/movielens-data/",
dataset=movielens.ML_1M,
train_epochs=2,
batch_size=99000,
tpu=None
)
# Add ncf-specific flags
flags.DEFINE_enum(
name="dataset", default="ml-1m",
enum_values=["ml-1m", "ml-20m"], case_sensitive=False,
help=flags_core.help_wrap(
"Dataset to be trained and evaluated."))
flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap(
"Download data to data_dir if it is not already present."))
......