"""
Script that reads from raw MovieLens-1M data and dumps into a pickle
file the following:

* A heterogeneous graph with categorical features.
* A list with all the movie titles.  The movie titles correspond to
  the movie nodes in the heterogeneous graph.

This script exemplifies how to prepare tabular data with textual
features.  Since DGL graphs do not store variable-length features, we
instead put variable-length features into a more suitable container
(e.g. torchtext to handle lists of texts).
"""

import argparse
import os
import pickle
import re

import numpy as np
import pandas as pd
import scipy.sparse as ssp
import torch
import torchtext
from builder import PandasGraphBuilder
from data_utils import *

import dgl

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str)
    parser.add_argument("out_directory", type=str)
    args = parser.parse_args()
    directory = args.directory
    out_directory = args.out_directory
    os.makedirs(out_directory, exist_ok=True)

    ## Build heterogeneous graph

    # Load data
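    # users.dat holds one "::"-separated record per user, e.g.
    #     1::F::1::10::48067
    # i.e. UserID::Gender::Age::Occupation::Zip-code.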
    users = []
    with open(os.path.join(directory, "users.dat"), encoding="latin1") as f:
        for l in f:
            id_, gender, age, occupation, zip_ = l.strip().split("::")
            users.append(
                {
                    "user_id": int(id_),
                    "gender": gender,
                    "age": age,
                    "occupation": occupation,
                    "zip": zip_,
                }
            )
    users = pd.DataFrame(users).astype("category")
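    # Casting every column to "category" lets us later read off integer codes
    # per field; the raw strings themselves are never stored on the graph.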

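    # movies.dat holds one "::"-separated record per movie, e.g.
    #     1::Toy Story (1995)::Animation|Children's|Comedy
    # i.e. MovieID::Title (Year)::Genres, with genres separated by "|".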
    movies = []
    with open(os.path.join(directory, "movies.dat"), encoding="latin1") as f:
        for l in f:
            id_, title, genres = l.strip().split("::")
            genres_set = set(genres.split("|"))

            # extract year
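            # Titles look like "Toy Story (1995)": the four digits in the
            # trailing " (YYYY)" are the release year.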
            assert re.match(r".*\([0-9]{4}\)$", title)
            year = title[-5:-1]
            title = title[:-6].strip()

            data = {"movie_id": int(id_), "title": title, "year": year}
            for g in genres_set:
                data[g] = True
            movies.append(data)
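    # Each genre becomes its own boolean column; movies lacking a genre get
    # NaN in that column, which is filled with False further below.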
    movies = pd.DataFrame(movies).astype({"year": "category"})

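    # ratings.dat holds one "::"-separated record per rating event, e.g.
    #     1::1193::5::978300760
    # i.e. UserID::MovieID::Rating (1-5)::Timestamp (seconds since epoch).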
    ratings = []
    with open(os.path.join(directory, "ratings.dat"), encoding="latin1") as f:
        for l in f:
            user_id, movie_id, rating, timestamp = [
                int(_) for _ in l.split("::")
            ]
            ratings.append(
                {
                    "user_id": user_id,
                    "movie_id": movie_id,
                    "rating": rating,
                    "timestamp": timestamp,
                }
            )
    ratings = pd.DataFrame(ratings)

    # Filter the users and items that never appear in the rating table.
    distinct_users_in_ratings = ratings["user_id"].unique()
    distinct_movies_in_ratings = ratings["movie_id"].unique()
    users = users[users["user_id"].isin(distinct_users_in_ratings)]
    movies = movies[movies["movie_id"].isin(distinct_movies_in_ratings)]

    # Group the movie features into genres (a vector), year (a category), title (a string)
    genre_columns = movies.columns.drop(["movie_id", "title", "year"])
    movies[genre_columns] = movies[genre_columns].fillna(False).astype("bool")
    movies_categorical = movies.drop("title", axis=1)
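    # The free-text title is dropped here because DGL node features must be
    # fixed-size tensors; the raw titles are kept separately in
    # movie_textual_dataset below.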

    # Build graph
    graph_builder = PandasGraphBuilder()
    graph_builder.add_entities(users, "user_id", "user")
    graph_builder.add_entities(movies_categorical, "movie_id", "movie")
    graph_builder.add_binary_relations(
        ratings, "user_id", "movie_id", "watched"
    )
    graph_builder.add_binary_relations(
        ratings, "movie_id", "user_id", "watched-by"
    )
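    # Each DataFrame row becomes one node or edge; "watched-by" mirrors
    # "watched" so that messages can also flow from movies back to users.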

    g = graph_builder.build()
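    # For the full MovieLens-1M dump, g should contain about 6040 user nodes,
    # 3706 movie nodes, and 1000209 "watched" edges (mirrored by "watched-by").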

    # Assign features.
    # Note that variable-sized features such as texts or images are handled elsewhere.
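    # cat.codes maps every categorical value to an integer ID, which the
    # model can feed directly into an nn.Embedding lookup.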
    for data_type in ["gender", "age", "occupation", "zip"]:
        g.nodes["user"].data[data_type] = torch.LongTensor(
            np.array(users[data_type].cat.codes.values)
        )

    g.nodes["movie"].data["year"] = torch.LongTensor(
122
        np.array(movies["year"].cat.codes.values)
123
124
    )
    g.nodes["movie"].data["genre"] = torch.FloatTensor(
125
        np.array(movies[genre_columns].values)
126
127
    )

    for edge_type in ["watched", "watched-by"]:
        for data_type in ["rating", "timestamp"]:
            g.edges[edge_type].data[data_type] = torch.LongTensor(
                np.array(ratings[data_type].values)
            )
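    # Both edge types were added from the same ratings table in the same row
    # order, so a single feature vector lines up with either direction.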

    # Train-validation-test split
    # This is a little bit tricky as we want to select the last interaction for test, and the
    # second-to-last interaction for validation.
    train_indices, val_indices, test_indices = train_test_split_by_time(
        ratings, "timestamp", "user_id"
    )
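    # The three arrays hold integer row indices into ratings, split per user
    # by timestamp as described above (see data_utils).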

    # Build the graph with training interactions only.
    train_g = build_train_graph(
        g, train_indices, "user", "movie", "watched", "watched-by"
    )
    assert train_g.out_degrees(etype="watched").min() > 0
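    # The assert guarantees that every user keeps at least one training
    # interaction; an isolated user node could break neighbor sampling later.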

    # Build the user-item sparse matrix for validation and test set.
    val_matrix, test_matrix = build_val_test_matrix(
        g, val_indices, test_indices, "user", "movie", "watched"
    )
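    # val_matrix and test_matrix are sparse user-by-movie matrices holding
    # only the held-out ratings.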

    ## Build title set

    movie_textual_dataset = {"title": movies["title"].values}

    # The model should build its own vocabulary and process the texts.  Here is
    # one example of using the legacy torchtext API to pad and numericalize a
    # batch of strings:
    #     field = torchtext.data.Field(include_lengths=True, lower=True, batch_first=True)
    #     examples = [torchtext.data.Example.fromlist([t], [('title', field)]) for t in texts]
    #     titleset = torchtext.data.Dataset(examples, [('title', field)])
    #     field.build_vocab(titleset.title, vectors='fasttext.simple.300d')
    #     token_ids, lengths = field.process([examples[0].title, examples[1].title])
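    # NOTE: torchtext.data.Field belongs to the legacy torchtext API and has
    # been removed in recent releases.  An untested sketch of the same idea
    # with the newer vocab API (assuming torchtext >= 0.12) might look like:
    #     from torchtext.data.utils import get_tokenizer
    #     from torchtext.vocab import build_vocab_from_iterator
    #     tokenizer = get_tokenizer("basic_english")
    #     vocab = build_vocab_from_iterator(
    #         (tokenizer(t) for t in texts), specials=["<pad>", "<unk>"])
    #     vocab.set_default_index(vocab["<unk>"])
    #     ids = [torch.LongTensor(vocab(tokenizer(t))) for t in texts]
    #     lengths = torch.LongTensor([len(i) for i in ids])
    #     token_ids = torch.nn.utils.rnn.pad_sequence(
    #         ids, batch_first=True, padding_value=vocab["<pad>"])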

    ## Dump the graph and the datasets

    dgl.save_graphs(os.path.join(out_directory, "train_g.bin"), train_g)

    dataset = {
        "val-matrix": val_matrix,
        "test-matrix": test_matrix,
        "item-texts": movie_textual_dataset,
        "item-images": None,
        "user-type": "user",
        "item-type": "movie",
        "user-to-item-type": "watched",
        "item-to-user-type": "watched-by",
        "timestamp-edge-column": "timestamp",
    }

    with open(os.path.join(out_directory, "data.pkl"), "wb") as f:
        pickle.dump(dataset, f)
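    # A downstream training script can reload these artifacts with, e.g.:
    #     train_g = dgl.load_graphs(os.path.join(out_directory, "train_g.bin"))[0][0]
    #     with open(os.path.join(out_directory, "data.pkl"), "rb") as f:
    #         dataset = pickle.load(f)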