# process_movielens1m.py
"""
Script that reads the raw MovieLens-1M data and writes the following
into the output directory:

* ``train_g.bin``: a heterogeneous graph with categorical features,
  built from the training interactions only.
* ``data.pkl``: a pickle file holding the validation and test user-item
  matrices, a list with all the movie titles, and the names of the
  node/edge types.  The movie titles correspond to the movie nodes in
  the heterogeneous graph.

This script exemplifies how to prepare tabular data with textual
features.  Since DGL graphs do not store variable-length features, we
instead put variable-length features into a more suitable container
(e.g. torchtext to handle a list of texts).
"""

import argparse
16
import os
17
import pickle
18
19
import re

20
import numpy as np
21
import pandas as pd
22
23
24
25
26
27
import scipy.sparse as ssp
import torch
import torchtext
from builder import PandasGraphBuilder
from data_utils import *

28
29
30
import dgl

if __name__ == "__main__":
    # Command line: <directory holding the raw *.dat files> <output directory>.
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str)
    parser.add_argument("out_directory", type=str)
    args = parser.parse_args()
    directory = args.directory
    out_directory = args.out_directory
    os.makedirs(out_directory, exist_ok=True)

    ## Build heterogeneous graph

    # Load data
    # Each raw file holds one record per line with "::"-separated fields.
    # Blank lines are skipped so that a stray empty line does not break the
    # tuple unpacking.
    users = []
    with open(os.path.join(directory, "users.dat"), encoding="latin1") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            id_, gender, age, occupation, zip_ = line.split("::")
            users.append(
                {
                    "user_id": int(id_),
                    "gender": gender,
                    "age": age,
                    "occupation": occupation,
                    "zip": zip_,
                }
            )
    # Every user column (the ID included) is stored as a categorical.
    users = pd.DataFrame(users).astype("category")

    # A title is expected to end with the release year, e.g. "Name (1995)".
    title_with_year = re.compile(r".*\([0-9]{4}\)$")

    movies = []
    with open(os.path.join(directory, "movies.dat"), encoding="latin1") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            id_, title, genres = line.split("::")
            genres_set = set(genres.split("|"))

            # extract year
            # Raise explicitly instead of asserting: assertions are removed
            # when Python runs with -O, which would let malformed titles
            # through silently.
            if not title_with_year.match(title):
                raise ValueError(
                    f"movie title does not end with a '(YYYY)' year: {title!r}"
                )
            year = title[-5:-1]
            title = title[:-6].strip()

            data = {"movie_id": int(id_), "title": title, "year": year}
            # One boolean column per genre; genres a movie does not have are
            # left missing here and filled with False further below.
            for genre in genres_set:
                data[genre] = True
            movies.append(data)
    movies = pd.DataFrame(movies).astype({"year": "category"})

    ratings = []
    with open(os.path.join(directory, "ratings.dat"), encoding="latin1") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            user_id, movie_id, rating, timestamp = (
                int(field) for field in line.split("::")
            )
            ratings.append(
                {
                    "user_id": user_id,
                    "movie_id": movie_id,
                    "rating": rating,
                    "timestamp": timestamp,
                }
            )
    ratings = pd.DataFrame(ratings)

    # Filter the users and items that never appear in the rating table.
    distinct_users_in_ratings = ratings["user_id"].unique()
    distinct_movies_in_ratings = ratings["movie_id"].unique()
    users = users[users["user_id"].isin(distinct_users_in_ratings)]
    # Take an explicit copy: the genre columns are assigned to below, and
    # writing into a boolean-filtered slice triggers pandas'
    # SettingWithCopyWarning.
    movies = movies[movies["movie_id"].isin(distinct_movies_in_ratings)].copy()

    # Group the movie features into genres (a vector), year (a category), title (a string)
    genre_columns = movies.columns.drop(["movie_id", "title", "year"])
    movies[genre_columns] = movies[genre_columns].fillna(False).astype("bool")
    # The title is variable-length text, so it is kept out of the graph.
    movies_categorical = movies.drop("title", axis=1)

    # Build graph
    graph_builder = PandasGraphBuilder()
    graph_builder.add_entities(users, "user_id", "user")
    graph_builder.add_entities(movies_categorical, "movie_id", "movie")
    # The same rating table is added once per direction.
    graph_builder.add_binary_relations(
        ratings, "user_id", "movie_id", "watched"
    )
    graph_builder.add_binary_relations(
        ratings, "movie_id", "user_id", "watched-by"
    )

    g = graph_builder.build()

    # Assign features.
    # Note that variable-sized features such as texts or images are handled elsewhere.
    # NOTE(review): this assumes the builder keeps nodes and edges in the row
    # order of the data frames passed in above -- confirm against
    # PandasGraphBuilder.
    for column in ("gender", "age", "occupation", "zip"):
        g.nodes["user"].data[column] = torch.LongTensor(
            users[column].cat.codes.values
        )

    g.nodes["movie"].data["year"] = torch.LongTensor(
        movies["year"].cat.codes.values
    )
    g.nodes["movie"].data["genre"] = torch.FloatTensor(
        movies[genre_columns].values
    )

    # Both edge types carry the same rating and timestamp columns.
    for etype in ("watched", "watched-by"):
        for column in ("rating", "timestamp"):
            g.edges[etype].data[column] = torch.LongTensor(
                ratings[column].values
            )

    # Train-validation-test split
    # This is a little bit tricky as we want to select the last interaction for test, and the
    # second-to-last interaction for validation.
    train_indices, val_indices, test_indices = train_test_split_by_time(
        ratings, "timestamp", "user_id"
    )

    # Build the graph with training interactions only.
    train_g = build_train_graph(
        g, train_indices, "user", "movie", "watched", "watched-by"
    )
    # Sanity check: the smallest out-degree over "watched" edges in the
    # training graph must be positive.
    assert train_g.out_degrees(etype="watched").min() > 0

    # Build the user-item sparse matrix for validation and test set.
    val_matrix, test_matrix = build_val_test_matrix(
        g, val_indices, test_indices, "user", "movie", "watched"
    )

    ## Build title set

    movie_textual_dataset = {"title": movies["title"].values}

    # The model should build their own vocabulary and process the texts.  Here is one example
    # of using torchtext to pad and numericalize a batch of strings.
    #     field = torchtext.data.Field(include_lengths=True, lower=True, batch_first=True)
    #     examples = [torchtext.data.Example.fromlist([t], [('title', title_field)]) for t in texts]
    #     titleset = torchtext.data.Dataset(examples, [('title', title_field)])
    #     field.build_vocab(titleset.title, vectors='fasttext.simple.300d')
    #     token_ids, lengths = field.process([examples[0].title, examples[1].title])

    ## Dump the graph and the datasets

    # The training graph goes into its own DGL binary file; everything else
    # is pickled together.
    dgl.save_graphs(os.path.join(out_directory, "train_g.bin"), train_g)

    dataset = {
        "val-matrix": val_matrix,
        "test-matrix": test_matrix,
        "item-texts": movie_textual_dataset,
        "item-images": None,
        "user-type": "user",
        "item-type": "movie",
        "user-to-item-type": "watched",
        "item-to-user-type": "watched-by",
        "timestamp-edge-column": "timestamp",
    }

    with open(os.path.join(out_directory, "data.pkl"), "wb") as f:
        pickle.dump(dataset, f)