Unverified Commit 81d77669, authored by Yanhui Liang and committed by GitHub

Add recommendation model (#4175)

* Add recommendation model

* Fix pylints check error

* Rename file

* Address comments, update input pipeline, and add distribution strategy

* Fix import error

* Address more comments

* Fix lints
# Recommendation Model
## Overview
This is an implementation of the Neural Collaborative Filtering (NCF) framework with the Neural Matrix Factorization (NeuMF) model, as described in the [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) paper. The current implementation is based on the code from the authors' [NCF code](https://github.com/hexiangnan/neural_collaborative_filtering) and the Stanford implementation in the [MLPerf Repo](https://github.com/mlperf/reference/tree/master/recommendation/pytorch).
NCF is a general framework for collaborative filtering of recommendations in which a neural network architecture is used to model user-item interactions. Unlike traditional models, NCF does not resort to Matrix Factorization (MF) with an inner product on latent features of users and items. It replaces the inner product with a multi-layer perceptron that can learn an arbitrary function from data.
Two instantiations of NCF are Generalized Matrix Factorization (GMF) and Multi-Layer Perceptron (MLP). GMF applies a linear kernel to model the latent feature interactions, and MLP uses a nonlinear kernel to learn the interaction function from data. NeuMF is a fused model of GMF and MLP to better model the complex user-item interactions, and unifies the strengths of linearity of MF and non-linearity of MLP for modeling the user-item latent structures. NeuMF allows GMF and MLP to learn separate embeddings, and combines the two models by concatenating their last hidden layer. [neumf_model.py](neumf_model.py) defines the architecture details.
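The fused architecture can be summarized with a short Keras-style sketch. This is an illustrative outline only, with placeholder layer sizes and without the embedding initializers and regularizers; the actual model is defined in [neumf_model.py](neumf_model.py):
```
import tensorflow as tf

def neumf_sketch(num_users, num_items, mf_dim=8, mlp_layers=(64, 32, 16, 8)):
  """Minimal NeuMF sketch: GMF and MLP branches with separate embeddings."""
  user = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="user_id")
  item = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="item_id")

  # GMF branch: element-wise product of user and item embeddings (linear kernel).
  mf_user = tf.keras.layers.Flatten()(
      tf.keras.layers.Embedding(num_users, mf_dim)(user))
  mf_item = tf.keras.layers.Flatten()(
      tf.keras.layers.Embedding(num_items, mf_dim)(item))
  mf_vector = tf.keras.layers.multiply([mf_user, mf_item])

  # MLP branch: concatenated embeddings passed through dense layers (nonlinear kernel).
  mlp_user = tf.keras.layers.Flatten()(
      tf.keras.layers.Embedding(num_users, mlp_layers[0] // 2)(user))
  mlp_item = tf.keras.layers.Flatten()(
      tf.keras.layers.Embedding(num_items, mlp_layers[0] // 2)(item))
  mlp_vector = tf.keras.layers.concatenate([mlp_user, mlp_item])
  for units in mlp_layers[1:]:
    mlp_vector = tf.keras.layers.Dense(units, activation="relu")(mlp_vector)

  # NeuMF: concatenate the last hidden layers of both branches and predict a score.
  merged = tf.keras.layers.concatenate([mf_vector, mlp_vector])
  prediction = tf.keras.layers.Dense(1, activation="sigmoid", name="rating")(merged)
  return tf.keras.Model(inputs=[user, item], outputs=prediction)
```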
Some abbreviations used in the code base include:
- NCF: Neural Collaborative Filtering
- NeuMF: Neural Matrix Factorization
- GMF: Generalized Matrix Factorization
- MLP: Multi-Layer Perceptron
- HR: Hit Ratio
- NDCG: Normalized Discounted Cumulative Gain
- ml-1m: MovieLens 1 million dataset
- ml-20m: MovieLens 20 million dataset
## Dataset
The [MovieLens datasets](http://files.grouplens.org/datasets/movielens/) are used for model training and evaluation. Specifically, we use two datasets: **ml-1m** (short for MovieLens 1 million) and **ml-20m** (short for MovieLens 20 million).
### ml-1m
The ml-1m dataset contains 1,000,209 anonymous ratings of approximately 3,706 movies made by 6,040 users who joined MovieLens in 2000. All ratings are contained in the file "ratings.dat", which has no header row; each line is in the following format:
```
UserID::MovieID::Rating::Timestamp
```
- UserIDs range between 1 and 6040.
- MovieIDs range between 1 and 3952.
- Ratings are made on a 5-star scale (whole-star ratings only).
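For illustration, `ratings.dat` can be loaded with Pandas along these lines (a sketch of the loading step performed by `data_download.py`; the column names follow this repo's conventions):
```
import pandas as pd

# "::"-separated file with no header row; the python engine is required
# because the separator is longer than one character.
ratings = pd.read_csv(
    "ml-1m/ratings.dat", sep="::", engine="python",
    names=["user_id", "item_id", "rating", "timestamp"])
```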
### ml-20m
The ml-20m dataset contains 20,000,263 ratings of 26,744 movies by 138,493 users. All ratings are contained in the file "ratings.csv". Each line after the header row represents one rating of one movie by one user, in the following format:
```
userId,movieId,rating,timestamp
```
- The lines within this file are ordered first by userId, then, within user, by movieId.
- Ratings are made on a 5-star scale, with half-star increments (0.5 stars - 5.0 stars).
In both datasets, the timestamp is represented in seconds since midnight Coordinated Universal Time (UTC) of January 1, 1970. Each user has at least 20 ratings.
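For example, the ml-20m ratings can be loaded and the timestamps converted in a few lines (a sketch mirroring the preprocessing in `data_download.py`):
```
import pandas as pd

# ml-20m ships as a standard csv with a header row.
ratings = pd.read_csv("ml-20m/ratings.csv").rename(
    columns={"userId": "user_id", "movieId": "item_id"})

# Timestamps are seconds since the Unix epoch (UTC).
ratings["timestamp"] = pd.to_datetime(ratings["timestamp"], unit="s")

# Sanity check: every user has at least 20 ratings.
assert ratings.groupby("user_id").size().min() >= 20
```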
## Running Code
### Download and preprocess dataset
To download the dataset, first install the Pandas package, then issue the following command:
```
python data_download.py
```
Arguments:
* `--data_dir`: Directory in which to download the raw data and save the preprocessed data. By default, it is `/tmp/movielens-data/`.
* `--dataset`: The dataset name to be downloaded and preprocessed. By default, it is `ml-1m`.
Use the `--help` or `-h` flag to get a full list of possible arguments.
Note that the ml-20m dataset is large (the ratings file is ~500 MB), and data preprocessing may take several minutes (~10 minutes).
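For example, to download and preprocess ml-20m into a custom directory:
```
python data_download.py --data_dir /tmp/movielens-data/ --dataset ml-20m
```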
### Train and evaluate model
To train and evaluate the model, issue the following command:
```
python ncf_main.py
```
Arguments:
* `--model_dir`: Directory to save model training checkpoints. By default, it is `/tmp/ncf/`.
* `--data_dir`: This should be set to the same directory passed to `data_download.py` via its `--data_dir` argument.
* `--dataset`: The name of the dataset to train and evaluate on. By default, it is `ml-1m`.
There are other arguments that control the model and the training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.
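For example, to train on ml-20m with a larger batch size and 10 training epochs:
```
python ncf_main.py --data_dir /tmp/movielens-data/ --dataset ml-20m --batch_size 1024 --train_epochs 10
```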
## Benchmarks (TODO)
### Training times
### Evaluation results
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""NCF Constants."""
TRAIN_RATINGS_FILENAME = 'train-ratings.csv'
TEST_RATINGS_FILENAME = 'test-ratings.csv'
TEST_NEG_FILENAME = 'test-negative.csv'
TRAIN_DATA = 'train_data.csv'
TEST_DATA = 'test_data.csv'
USER = "user_id"
ITEM = "item_id"
RATING = "rating"
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and extract the MovieLens dataset from GroupLens website.
Download the dataset, and perform data preprocessing to convert the raw dataset
into csv files to be used in model training and evaluation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import os
import sys
import time
import zipfile
import numpy as np
import pandas as pd
from six.moves import urllib # pylint: disable=redefined-builtin
import tensorflow as tf
from official.recommendation import constants # pylint: disable=g-bad-import-order
# URL to download dataset
_DATA_URL = "http://files.grouplens.org/datasets/movielens/"
_USER_COLUMN = "user_id"
_ITEM_COLUMN = "item_id"
_RATING_COLUMN = "rating"
_TIMESTAMP_COLUMN = "timestamp"
# The number of negative examples sampled for each positive test example
# (i.e. per user). It is set to 100 in the paper.
_NUMBER_NEGATIVES = 100
# In both datasets, each user has at least 20 ratings.
_MIN_NUM_RATINGS = 20
RatingData = collections.namedtuple(
"RatingData", ["items", "users", "ratings", "min_date", "max_date"])
def _print_ratings_description(ratings):
"""Describe the rating dataset information.
Args:
ratings: A pandas DataFrame of the rating dataset.
"""
info = RatingData(items=len(ratings[_ITEM_COLUMN].unique()),
users=len(ratings[_USER_COLUMN].unique()),
ratings=len(ratings),
min_date=ratings[_TIMESTAMP_COLUMN].min(),
max_date=ratings[_TIMESTAMP_COLUMN].max())
tf.logging.info("{ratings} ratings on {items} items from {users} users"
" from {min_date} to {max_date}".format(**(info._asdict())))
def process_movielens(ratings, sort=True):
"""Sort and convert timestamp of the MovieLens dataset.
Args:
ratings: A pandas DataFrame of the rating dataset.
sort: A boolean to indicate whether to sort the data based on timestamp.
Returns:
ratings: The processed pandas DataFrame.
"""
ratings[_TIMESTAMP_COLUMN] = pd.to_datetime(
ratings[_TIMESTAMP_COLUMN], unit="s")
if sort:
ratings.sort_values(by=_TIMESTAMP_COLUMN, inplace=True)
_print_ratings_description(ratings)
return ratings
def load_movielens_1_million(file_name, sort=True):
"""Load the MovieLens 1 million dataset.
The file has no header row, and each line is in the following format:
UserID::MovieID::Rating::Timestamp
- UserIDs range between 1 and 6040
- MovieIDs range between 1 and 3952
- Ratings are made on a 5-star scale (whole-star ratings only)
- Timestamp is represented in seconds since midnight Coordinated Universal
Time (UTC) of January 1, 1970.
- Each user has at least 20 ratings
Args:
file_name: A string of the file name to be loaded.
sort: A boolean to indicate whether to sort the data based on timestamp.
Returns:
A processed pandas DataFrame of the rating dataset.
"""
names = [_USER_COLUMN, _ITEM_COLUMN, _RATING_COLUMN, _TIMESTAMP_COLUMN]
ratings = pd.read_csv(file_name, sep="::", names=names, engine="python")
return process_movielens(ratings, sort=sort)
def load_movielens_20_million(file_name, sort=True):
"""Load the MovieLens 20 million dataset.
Each line of this file after the header row represents one rating of one movie
by one user, and has the following format:
userId,movieId,rating,timestamp
- The lines within this file are ordered first by userId, then, within user,
by movieId.
- Ratings are made on a 5-star scale, with half-star increments
(0.5 stars - 5.0 stars).
- Timestamps represent seconds since midnight Coordinated Universal Time
(UTC) of January 1, 1970.
- All the users had rated at least 20 movies.
Args:
file_name: A string of the file name to be loaded.
sort: A boolean to indicate whether to sort the data based on timestamp.
Returns:
A processed pandas DataFrame of the rating dataset.
"""
ratings = pd.read_csv(file_name)
names = {"userId": _USER_COLUMN, "movieId": _ITEM_COLUMN}
ratings.rename(columns=names, inplace=True)
return process_movielens(ratings, sort=sort)
def load_file_to_df(file_name, sort=True):
"""Load rating dataset into DataFrame.
  Two data loading functions are defined to handle the ml-1m and ml-20m
  datasets, as they are provided in different formats.
  Args:
    file_name: A string, the path of the directory containing the dataset;
      the ratings file inside it is loaded.
sort: A boolean to indicate whether to sort the data based on timestamp.
Returns:
A pandas DataFrame of the rating dataset.
"""
dataset_name = os.path.basename(file_name).split(".")[0]
# ml-1m with extension .dat
file_extension = ".dat"
func = load_movielens_1_million
if dataset_name == "ml-20m":
file_extension = ".csv"
func = load_movielens_20_million
ratings_file = os.path.join(file_name, "ratings" + file_extension)
return func(ratings_file, sort=sort)
def generate_train_eval_data(df, original_users, original_items):
"""Generate the dataset for model training and evaluation.
Given all user and item interaction information, for each user, first sort
the interactions based on timestamp. Then the latest one is taken out as
Test ratings (leave-one-out evaluation) and the remaining data for training.
The Test negatives are randomly sampled from all non-interacted items, and the
number of Test negatives is 100 by default (defined as _NUMBER_NEGATIVES).
Args:
df: The DataFrame of ratings data.
original_users: A list of the original unique user ids in the dataset.
original_items: A list of the original unique item ids in the dataset.
Returns:
all_ratings: A list of the [user_id, item_id] with interactions.
test_ratings: A list of [user_id, item_id], and each line is the latest
user_item interaction for the user.
test_negs: A list of item ids with shape [num_users, 100].
Each line consists of 100 item ids for the user with no interactions.
"""
# Need to sort before popping to get last item
tf.logging.info("Sorting user_item_map by timestamp...")
df.sort_values(by=_TIMESTAMP_COLUMN, inplace=True)
all_ratings = set(zip(df[_USER_COLUMN], df[_ITEM_COLUMN]))
user_to_items = collections.defaultdict(list)
# Generate user_item rating matrix for training
t1 = time.time()
row_count = 0
for row in df.itertuples():
user_to_items[getattr(row, _USER_COLUMN)].append(getattr(row, _ITEM_COLUMN))
row_count += 1
if row_count % 50000 == 0:
tf.logging.info("Processing user_to_items row: {}".format(row_count))
tf.logging.info(
"Process {} rows in [{:.1f}]s".format(row_count, time.time() - t1))
# Generate test ratings and test negatives
t2 = time.time()
test_ratings = []
test_negs = []
# Generate the 0-based index for each item, and put it into a set
all_items = set(range(len(original_items)))
for user in range(len(original_users)):
test_item = user_to_items[user].pop() # Get the latest item id
all_ratings.remove((user, test_item)) # Remove the test item
all_negs = all_items.difference(user_to_items[user])
all_negs = sorted(list(all_negs)) # determinism
test_ratings.append((user, test_item))
test_negs.append(list(np.random.choice(all_negs, _NUMBER_NEGATIVES)))
if user % 1000 == 0:
tf.logging.info("Processing user: {}".format(user))
tf.logging.info("Process {} users in {:.1f}s".format(
len(original_users), time.time() - t2))
all_ratings = list(all_ratings) # convert set to list
return all_ratings, test_ratings, test_negs
def parse_file_to_csv(data_dir, dataset_name):
"""Parse the raw data to csv file to be used in model training and evaluation.
  The ml-1m dataset is small (~25 MB), while ml-20m is large (~500 MB). It may
  take several minutes to process the ml-20m dataset.
Args:
data_dir: A string, the directory with the unzipped dataset.
dataset_name: A string, the dataset name to be processed.
"""
  # Fix the random seed so the sampled test negatives are deterministic
np.random.seed(0)
# Load the file as DataFrame
file_path = os.path.join(data_dir, dataset_name)
df = load_file_to_df(file_path, sort=False)
# Get the info of users who have more than 20 ratings on items
grouped = df.groupby(_USER_COLUMN)
df = grouped.filter(lambda x: len(x) >= _MIN_NUM_RATINGS)
original_users = df[_USER_COLUMN].unique()
original_items = df[_ITEM_COLUMN].unique()
# Map the ids of user and item to 0 based index for following processing
tf.logging.info("Generating user_map and item_map...")
user_map = {user: index for index, user in enumerate(original_users)}
item_map = {item: index for index, item in enumerate(original_items)}
df[_USER_COLUMN] = df[_USER_COLUMN].apply(lambda user: user_map[user])
df[_ITEM_COLUMN] = df[_ITEM_COLUMN].apply(lambda item: item_map[item])
assert df[_USER_COLUMN].max() == len(original_users) - 1
assert df[_ITEM_COLUMN].max() == len(original_items) - 1
# Generate data for train and test
all_ratings, test_ratings, test_negs = generate_train_eval_data(
df, original_users, original_items)
# Serialize to csv file. Each csv file contains three columns
# (user_id, item_id, interaction)
# As there are only two fields (user_id, item_id) in all_ratings and
# test_ratings, we need to add a fake rating to make three columns
df_train_ratings = pd.DataFrame(all_ratings)
df_train_ratings["fake_rating"] = 1
train_ratings_file = os.path.join(
FLAGS.data_dir, dataset_name + "-" + constants.TRAIN_RATINGS_FILENAME)
df_train_ratings.to_csv(
train_ratings_file,
index=False, header=False, sep="\t")
tf.logging.info("Train ratings is {}".format(train_ratings_file))
df_test_ratings = pd.DataFrame(test_ratings)
df_test_ratings["fake_rating"] = 1
test_ratings_file = os.path.join(
FLAGS.data_dir, dataset_name + "-" + constants.TEST_RATINGS_FILENAME)
df_test_ratings.to_csv(
test_ratings_file,
index=False, header=False, sep="\t")
tf.logging.info("Test ratings is {}".format(test_ratings_file))
df_test_negs = pd.DataFrame(test_negs)
test_negs_file = os.path.join(
FLAGS.data_dir, dataset_name + "-" + constants.TEST_NEG_FILENAME)
df_test_negs.to_csv(
test_negs_file,
index=False, header=False, sep="\t")
tf.logging.info("Test negatives is {}".format(test_negs_file))
def make_dir(file_dir):
if not tf.gfile.Exists(file_dir):
tf.logging.info("Creating directory {}".format(file_dir))
tf.gfile.MakeDirs(file_dir)
def main(_):
"""Download and extract the data from GroupLens website."""
tf.logging.set_verbosity(tf.logging.INFO)
make_dir(FLAGS.data_dir)
# Download the zip dataset
dataset_zip = FLAGS.dataset + ".zip"
file_path = os.path.join(FLAGS.data_dir, dataset_zip)
if not tf.gfile.Exists(file_path):
def _progress(count, block_size, total_size):
sys.stdout.write("\r>> Downloading {} {:.1f}%".format(
file_path, 100.0 * count * block_size / total_size))
sys.stdout.flush()
file_path, _ = urllib.request.urlretrieve(
_DATA_URL + dataset_zip, file_path, _progress)
statinfo = os.stat(file_path)
# A new line to clear the carriage return from download progress
# tf.logging.info is not applicable here
print()
tf.logging.info(
"Successfully downloaded {} {} bytes".format(
file_path, statinfo.st_size))
# Unzip the dataset
if not tf.gfile.Exists(os.path.join(FLAGS.data_dir, FLAGS.dataset)):
zipfile.ZipFile(file_path, "r").extractall(FLAGS.data_dir)
# Preprocess and parse the dataset to csv
train_ratings = FLAGS.dataset + "-" + constants.TRAIN_RATINGS_FILENAME
if not tf.gfile.Exists(os.path.join(FLAGS.data_dir, train_ratings)):
parse_file_to_csv(FLAGS.data_dir, FLAGS.dataset)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir", type=str, default="/tmp/movielens-data/",
help="Directory to download data and extract the zip.")
parser.add_argument(
"--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"],
help="Dataset to be trained and evaluated.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prepare dataset for NCF.
Load the training dataset and evaluation dataset from csv file into memory.
Prepare input for model training and evaluation.
"""
import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from official.recommendation import constants # pylint: disable=g-bad-import-order
# The column names and types of csv file
_CSV_COLUMN_NAMES = [constants.USER, constants.ITEM, constants.RATING]
_CSV_TYPES = [[0], [0], [0]]
# The buffer size for shuffling train dataset.
_SHUFFLE_BUFFER_SIZE = 1024
class NCFDataSet(object):
"""A class containing data information for model training and evaluation."""
def __init__(self, train_data, num_users, num_items, num_negatives,
true_items, all_items):
"""Initialize NCFDataset class.
Args:
train_data: A list containing the positive training instances.
num_users: An integer, the number of users in training dataset.
num_items: An integer, the number of items in training dataset.
num_negatives: An integer, the number of negative instances for each user
in train dataset.
true_items: A list, the ground truth (positive) items of users for
evaluation. Each entry is a latest positive instance for one user.
all_items: A nested list, all items for evaluation, and each entry is the
evaluation items for one user.
"""
self.train_data = train_data
self.num_users = num_users
self.num_items = num_items
self.num_negatives = num_negatives
self.eval_true_items = true_items
self.eval_all_items = all_items
def load_data(file_name):
"""Load data from a csv file which splits on \t."""
lines = tf.gfile.Open(file_name, "r").readlines()
# Process the file line by line
def _process_line(line):
return [int(col) for col in line.split("\t")]
data = [_process_line(line) for line in lines]
return data
def data_preprocessing(train_fname, test_fname, test_neg_fname, num_negatives):
"""Preprocess the train and test dataset.
In data preprocessing, the training positive instances are loaded into memory
for random negative instance generation in each training epoch. The test
  dataset is generated from the test positive and negative instances.
Args:
train_fname: A string, the file name of training positive dataset.
test_fname: A string, the file name of test positive dataset. Each user has
one positive instance.
test_neg_fname: A string, the file name of test negative dataset. Each user
has 100 negative instances by default.
num_negatives: An integer, the number of negative instances for each user
in train dataset.
Returns:
ncf_dataset: A NCFDataset object containing information about training and
evaluation/test dataset.
"""
# Load training positive instances into memory for later train data generation
train_data = load_data(train_fname)
# Get total number of users in the dataset
num_users = len(np.unique(np.array(train_data)[:, 0]))
# Process test dataset to csv file
test_ratings = load_data(test_fname)
test_negatives = load_data(test_neg_fname)
# Get the total number of items in both train dataset and test dataset (the
# whole dataset)
num_items = len(
set(np.array(train_data)[:, 1]) | set(np.array(test_ratings)[:, 1]))
# Generate test instances for each user
true_items, all_items = [], []
all_test_data = []
for idx in range(num_users):
items = test_negatives[idx]
rating = test_ratings[idx]
user = rating[0] # User
true_item = rating[1] # Positive item as ground truth
# All items with first 100 as negative and last one positive
items.append(true_item)
users = np.full(len(items), user, dtype=np.int32)
users_items = list(zip(users, items)) # User-item list
true_items.append(true_item) # all ground truth items
all_items.append(items) # All items (including positive and negative items)
all_test_data.extend(users_items) # Generate test dataset
# Save test dataset into csv file
np.savetxt(constants.TEST_DATA, np.asarray(all_test_data).astype(int),
fmt="%i", delimiter=",")
# Create NCFDataset object
ncf_dataset = NCFDataSet(
train_data, num_users, num_items, num_negatives, true_items, all_items)
return ncf_dataset
def generate_train_dataset(train_data, num_items, num_negatives):
"""Generate train dataset for each epoch.
Given positive training instances, randomly generate negative instances to
form the training dataset.
Args:
train_data: A list of positive training instances.
num_items: An integer, the number of items in positive training instances.
num_negatives: An integer, the number of negative training instances
following positive training instances. It is 4 by default.
"""
all_train_data = []
# A set with user-item tuples
train_data_set = set((u, i) for u, i, _ in train_data)
for u, i, _ in train_data:
# Positive instance
all_train_data.append([u, i, 1])
# Negative instances, randomly generated
for _ in xrange(num_negatives):
j = np.random.randint(num_items)
while (u, j) in train_data_set:
j = np.random.randint(num_items)
all_train_data.append([u, j, 0])
# Save the train dataset into a csv file
np.savetxt(constants.TRAIN_DATA, np.asarray(all_train_data).astype(int),
fmt="%i", delimiter=",")
def input_fn(training, batch_size, repeat=1, ncf_dataset=None,
num_parallel_calls=1):
"""Input function for model training and evaluation.
The train input consists of 1 positive instance (user and item have
interactions) followed by some number of negative instances in which the items
are randomly chosen. The number of negative instances is "num_negatives" which
is 4 by default. Note that for each epoch, we need to re-generate the negative
instances. Together with positive instances, they form a new train dataset.
Args:
training: A boolean flag for training mode.
batch_size: An integer, batch size for training and evaluation.
repeat: An integer, how many times to repeat the dataset.
ncf_dataset: An NCFDataSet object, which contains the information to
generate negative training instances.
num_parallel_calls: An integer, number of cpu cores for parallel input
processing.
Returns:
dataset: A tf.data.Dataset object containing examples loaded from the files.
"""
# Default test file name
file_name = constants.TEST_DATA
# Generate random negative instances for training in each epoch
if training:
t1 = time.time()
generate_train_dataset(
ncf_dataset.train_data, ncf_dataset.num_items,
ncf_dataset.num_negatives)
file_name = constants.TRAIN_DATA
tf.logging.info(
"Generating training instances: {:.1f}s".format(time.time() - t1))
# Create a dataset containing the text lines.
dataset = tf.data.TextLineDataset(file_name)
# Test dataset only has two fields while train dataset has three
num_cols = len(_CSV_COLUMN_NAMES) - 1
# Shuffle the dataset for training
if training:
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER_SIZE)
num_cols += 1
def _parse_csv(line):
"""Parse each line of the csv file."""
# Decode the line into its fields
fields = tf.decode_csv(line, record_defaults=_CSV_TYPES[0:num_cols])
fields = [tf.expand_dims(field, axis=0) for field in fields]
# Pack the result into a dictionary
features = dict(zip(_CSV_COLUMN_NAMES[0:num_cols], fields))
# Separate the labels from the features for training
if training:
labels = features.pop(constants.RATING)
return features, labels
# Return features only for test/prediction
return features
# Parse each line into a dictionary
dataset = dataset.map(_parse_csv, num_parallel_calls=num_parallel_calls)
# Repeat and batch the dataset
dataset = dataset.repeat(repeat)
dataset = dataset.batch(batch_size)
# Prefetch to improve speed of input pipeline.
dataset = dataset.prefetch(1)
return dataset
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""NCF framework to train and evaluate the NeuMF model.
The NeuMF model assembles both MF and MLP models under the NCF framework. Check
`neumf_model.py` for more details about the models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import heapq
import math
import os
import sys
import time
import numpy as np
import tensorflow as tf
# pylint: disable=g-bad-import-order
from official.recommendation import constants
from official.recommendation import dataset
from official.recommendation import neumf_model
_TOP_K = 10 # Top-k list for evaluation
_EVAL_BATCH_SIZE = 100
def evaluate_model(estimator, batch_size, num_gpus, true_items, all_items,
num_parallel_calls):
"""Model evaluation with HR and NDCG metrics.
  The evaluation protocol is to rank each user's held-out test item (true item)
  among 100 randomly chosen items that the user has not interacted with.
The performance of the ranked list is judged by Hit Ratio (HR) and Normalized
Discounted Cumulative Gain (NDCG).
For evaluation, the ranked list is truncated at 10 for both metrics. As such,
the HR intuitively measures whether the test item is present on the top-10
list, and the NDCG accounts for the position of the hit by assigning higher
scores to hits at top ranks. Both metrics are calculated for each test user,
and the average scores are reported.
Args:
estimator: The Estimator.
batch_size: An integer, the batch size specified by user.
num_gpus: An integer, the number of gpus specified by user.
true_items: A list of test items (true items) for HR and NDCG calculation.
Each item is for one user.
all_items: A nested list. Each entry is the 101 items (1 ground truth item
and 100 negative items) for one user.
num_parallel_calls: An integer, number of cpu cores for parallel input
processing in input_fn.
Returns:
    hr: A float, the average HR score over all users.
    ndcg: A float, the average NDCG score over all users.
"""
# Define prediction input function
def pred_input_fn():
return dataset.input_fn(
False, per_device_batch_size(batch_size, num_gpus),
num_parallel_calls=num_parallel_calls)
# Get predictions
predictions = estimator.predict(input_fn=pred_input_fn)
all_predicted_scores = [p[constants.RATING] for p in predictions]
# Calculate HR score
def _get_hr(ranklist, true_item):
return 1 if true_item in ranklist else 0
# Calculate NDCG score
def _get_ndcg(ranklist, true_item):
if true_item in ranklist:
return math.log(2) / math.log(ranklist.index(true_item) + 2)
return 0
hits, ndcgs = [], []
num_users = len(true_items)
# Reshape the predicted scores and each user takes one row
predicted_scores_list = np.asarray(
all_predicted_scores).reshape(num_users, -1)
for i in range(num_users):
items = all_items[i]
predicted_scores = predicted_scores_list[i]
# Map item and score for each user
map_item_score = {}
for j, item in enumerate(items):
score = predicted_scores[j]
map_item_score[item] = score
# Evaluate top rank list with HR and NDCG
ranklist = heapq.nlargest(_TOP_K, map_item_score, key=map_item_score.get)
true_item = true_items[i]
hr = _get_hr(ranklist, true_item)
ndcg = _get_ndcg(ranklist, true_item)
hits.append(hr)
ndcgs.append(ndcg)
# Get average HR and NDCG scores
hr, ndcg = np.array(hits).mean(), np.array(ndcgs).mean()
return hr, ndcg
def get_num_gpus(num_gpus):
"""Treat num_gpus=-1 as "use all"."""
if num_gpus != -1:
return num_gpus
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top
local_device_protos = device_lib.list_local_devices()
return sum([1 for d in local_device_protos if d.device_type == "GPU"])
def convert_keras_to_estimator(keras_model, num_gpus, model_dir):
"""Configure and convert keras model to Estimator.
Args:
keras_model: A Keras model object.
num_gpus: An integer, the number of gpus.
model_dir: A string, the directory to save and restore checkpoints.
Returns:
est_model: The converted Estimator.
"""
# TODO(b/79866338): update GradientDescentOptimizer with AdamOptimizer
optimizer = tf.train.GradientDescentOptimizer(
learning_rate=FLAGS.learning_rate)
keras_model.compile(optimizer=optimizer, loss="binary_crossentropy")
if num_gpus == 0:
distribution = tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
elif num_gpus == 1:
distribution = tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
else:
distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)
run_config = tf.estimator.RunConfig(train_distribute=distribution)
estimator = tf.keras.estimator.model_to_estimator(
keras_model=keras_model, model_dir=model_dir, config=run_config)
return estimator
def per_device_batch_size(batch_size, num_gpus):
"""For multi-gpu, batch-size must be a multiple of the number of GPUs.
Note that this should eventually be handled by DistributionStrategies
directly. Multi-GPU support is currently experimental, however,
so doing the work here until that feature is in place.
Args:
batch_size: Global batch size to be divided among devices. This should be
equal to num_gpus times the single-GPU batch_size for multi-gpu training.
num_gpus: How many GPUs are used with DistributionStrategies.
Returns:
Batch size per device.
Raises:
ValueError: if batch_size is not divisible by number of devices
"""
if num_gpus <= 1:
return batch_size
remainder = batch_size % num_gpus
if remainder:
err = ("When running with multiple GPUs, batch size "
"must be a multiple of the number of available GPUs. Found {} "
"GPUs with a batch size of {}; try --batch_size={} instead."
).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err)
return int(batch_size / num_gpus)
def main(_):
# Data preprocessing
# The file name of training and test dataset
train_fname = os.path.join(
FLAGS.data_dir, FLAGS.dataset + "-" + constants.TRAIN_RATINGS_FILENAME)
test_fname = os.path.join(
FLAGS.data_dir, FLAGS.dataset + "-" + constants.TEST_RATINGS_FILENAME)
neg_fname = os.path.join(
FLAGS.data_dir, FLAGS.dataset + "-" + constants.TEST_NEG_FILENAME)
t1 = time.time()
ncf_dataset = dataset.data_preprocessing(
train_fname, test_fname, neg_fname, FLAGS.num_neg)
tf.logging.info("Data preprocessing: {:.1f} s".format(time.time() - t1))
# Create NeuMF model and convert it to Estimator
tf.logging.info("Creating Estimator from Keras model...")
keras_model = neumf_model.NeuMF(
ncf_dataset.num_users, ncf_dataset.num_items, FLAGS.num_factors,
ast.literal_eval(FLAGS.layers), FLAGS.batch_size, FLAGS.mf_regularization)
num_gpus = get_num_gpus(FLAGS.num_gpus)
estimator = convert_keras_to_estimator(keras_model, num_gpus, FLAGS.model_dir)
# Training and evaluation cycle
def train_input_fn():
return dataset.input_fn(
True, per_device_batch_size(FLAGS.batch_size, num_gpus),
FLAGS.epochs_between_evals, ncf_dataset, FLAGS.num_parallel_calls)
total_training_cycle = (FLAGS.train_epochs //
FLAGS.epochs_between_evals)
for cycle_index in range(total_training_cycle):
tf.logging.info("Starting a training cycle: {}/{}".format(
cycle_index, total_training_cycle - 1))
# Train the model
train_cycle_begin = time.time()
estimator.train(input_fn=train_input_fn,
hooks=[tf.train.ProfilerHook(save_steps=10000)])
train_cycle_end = time.time()
# Evaluate the model
eval_cycle_begin = time.time()
hr, ndcg = evaluate_model(
estimator, FLAGS.batch_size, num_gpus, ncf_dataset.eval_true_items,
ncf_dataset.eval_all_items, FLAGS.num_parallel_calls)
eval_cycle_end = time.time()
# Log the train time, evaluation time, and HR and NDCG results.
tf.logging.info(
"Iteration {} [{:.1f} s]: HR = {:.4f}, NDCG = {:.4f}, [{:.1f} s]"
.format(cycle_index, train_cycle_end - train_cycle_begin, hr, ndcg,
eval_cycle_end - eval_cycle_begin))
# Remove temporary files
os.remove(constants.TRAIN_DATA)
os.remove(constants.TEST_DATA)
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_dir", nargs="?", default="/tmp/ncf/",
help="Model directory.")
parser.add_argument(
"--data_dir", nargs="?", default="/tmp/movielens-data/",
help="Input data directory. Should be the same as downloaded data dir.")
parser.add_argument(
"--dataset", nargs="?", default="ml-1m", choices=["ml-1m", "ml-20m"],
help="Choose a dataset.")
parser.add_argument(
"--train_epochs", type=int, default=20,
help="Number of epochs.")
parser.add_argument(
"--batch_size", type=int, default=256,
help="Batch size.")
parser.add_argument(
"--num_factors", type=int, default=8,
help="Embedding size of MF model.")
parser.add_argument(
"--layers", nargs="?", default="[64,32,16,8]",
help="Size of hidden layers for MLP.")
parser.add_argument(
"--mf_regularization", type=float, default=0,
help="Regularization for MF embeddings.")
parser.add_argument(
"--num_neg", type=int, default=4,
help="Number of negative instances to pair with a positive instance.")
parser.add_argument(
"--num_parallel_calls", type=int, default=8,
help="Number of parallel calls.")
parser.add_argument(
"--epochs_between_evals", type=int, default=1,
help="Number of epochs between model evaluation.")
parser.add_argument(
"--learning_rate", type=float, default=0.001,
help="Learning rate.")
parser.add_argument(
"--num_gpus", type=int, default=1 if tf.test.is_gpu_available() else 0,
help="How many GPUs to use with the DistributionStrategies API. The "
"default is 1 if TensorFlow can detect a GPU, and 0 otherwise.")
FLAGS, unparsed = parser.parse_known_args()
with tf.Graph().as_default():
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines NeuMF model for NCF framework.
Some abbreviations used in the code base:
NeuMF: Neural Matrix Factorization
NCF: Neural Collaborative Filtering
GMF: Generalized Matrix Factorization
MLP: Multi-Layer Perceptron
GMF applies a linear kernel to model the latent feature interactions, and MLP
uses a nonlinear kernel to learn the interaction function from data. NeuMF model
is a fused model of GMF and MLP to better model the complex user-item
interactions, and unifies the strengths of linearity of MF and non-linearity of
MLP for modeling the user-item latent structures.
The NeuMF model allows GMF and MLP to learn separate embeddings, and combines
the two models by concatenating their last hidden layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from official.recommendation import constants # pylint: disable=g-bad-import-order
class NeuMF(tf.keras.models.Model):
"""Neural matrix factorization (NeuMF) model for recommendations."""
def __init__(self, num_users, num_items, mf_dim, model_layers, batch_size,
mf_regularization=0):
"""Initialize NeuMF model.
Args:
num_users: An integer, the number of users.
num_items: An integer, the number of items.
mf_dim: An integer, the embedding size of Matrix Factorization (MF) model.
model_layers: A list of integers for Multi-Layer Perceptron (MLP) layers.
Note that the first layer is the concatenation of user and item
embeddings. So model_layers[0]//2 is the embedding size for MLP.
batch_size: An integer for the batch size.
mf_regularization: A floating number, the regularization for MF
embeddings.
Raises:
ValueError: if the first model layer is not even.
"""
if model_layers[0] % 2 != 0:
      raise ValueError("The first layer size should be a multiple of 2!")
# Input variables
user_input = tf.keras.layers.Input(
shape=(1,), dtype=tf.int32, name=constants.USER)
item_input = tf.keras.layers.Input(
shape=(1,), dtype=tf.int32, name=constants.ITEM)
# Initializer for embedding layer
embedding_initializer = tf.keras.initializers.RandomNormal(stddev=0.01)
# Embedding layers of GMF and MLP
mf_embedding_user = tf.keras.layers.Embedding(
num_users,
mf_dim,
embeddings_initializer=embedding_initializer,
embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
input_length=1)
mf_embedding_item = tf.keras.layers.Embedding(
num_items,
mf_dim,
embeddings_initializer=embedding_initializer,
embeddings_regularizer=tf.keras.regularizers.l2(mf_regularization),
input_length=1)
mlp_embedding_user = tf.keras.layers.Embedding(
num_users,
model_layers[0]//2,
embeddings_initializer=embedding_initializer,
embeddings_regularizer=tf.keras.regularizers.l2(model_layers[0]),
input_length=1)
mlp_embedding_item = tf.keras.layers.Embedding(
num_items,
model_layers[0]//2,
embeddings_initializer=embedding_initializer,
embeddings_regularizer=tf.keras.regularizers.l2(model_layers[0]),
input_length=1)
# GMF part
# Flatten the embedding vector as latent features in GMF
mf_user_latent = tf.keras.layers.Flatten()(mf_embedding_user(user_input))
mf_item_latent = tf.keras.layers.Flatten()(mf_embedding_item(item_input))
# Element-wise multiply
mf_vector = tf.keras.layers.multiply([mf_user_latent, mf_item_latent])
# MLP part
# Flatten the embedding vector as latent features in MLP
mlp_user_latent = tf.keras.layers.Flatten()(mlp_embedding_user(user_input))
mlp_item_latent = tf.keras.layers.Flatten()(mlp_embedding_item(item_input))
# Concatenation of two latent features
mlp_vector = tf.keras.layers.concatenate([mlp_user_latent, mlp_item_latent])
num_layer = len(model_layers) # Number of layers in the MLP
for idx in xrange(1, num_layer):
model_layer = tf.keras.layers.Dense(
model_layers[idx],
activation="relu")
mlp_vector = model_layer(mlp_vector)
# Concatenate GMF and MLP parts
predict_vector = tf.keras.layers.concatenate([mf_vector, mlp_vector])
# Final prediction layer
prediction = tf.keras.layers.Dense(
1, activation="sigmoid", kernel_initializer="lecun_uniform",
name=constants.RATING)(predict_vector)
super(NeuMF, self).__init__(
inputs=[user_input, item_input], outputs=prediction)