import os
import pickle as pkl
import random

import dgl

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def split_data(hg, etype_name):
    """Split the positive links of `etype_name`, plus an equal number of
    sampled negatives, into 60/20/20 train/eval/test arrays of
    (src, dst, label) triples."""
    src, dst = hg.edges(etype=etype_name)
    user_item_src = src.numpy().tolist()
    user_item_dst = dst.numpy().tolist()
    num_link = len(user_item_src)
    pos_label = [1] * num_link
    pos_data = list(zip(user_item_src, user_item_dst, pos_label))

    # Dense user-item adjacency; zeros mark candidate negative pairs.
    ui_adj = np.array(hg.adj(etype=etype_name).to_dense())
    full_idx = np.where(ui_adj == 0)

    # Sample as many negative pairs as there are positive links.
    sample = random.sample(range(0, len(full_idx[0])), num_link)
    neg_label = [0] * num_link
    neg_data = list(zip(full_idx[0][sample], full_idx[1][sample], neg_label))

    full_data = pos_data + neg_data
    random.shuffle(full_data)

    # 60/20/20 train/eval/test split.
    train_size = int(len(full_data) * 0.6)
    eval_size = int(len(full_data) * 0.2)
    test_size = len(full_data) - train_size - eval_size
    train_data = full_data[:train_size]
    eval_data = full_data[train_size : train_size + eval_size]
    test_data = full_data[
        train_size + eval_size : train_size + eval_size + test_size
    ]
    train_data = np.array(train_data)
    eval_data = np.array(eval_data)
    test_data = np.array(test_data)

    return train_data, eval_data, test_data
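
# Example usage (a hedged sketch; the toy graph below is hypothetical and not
# part of this module):
#
#     toy_hg = dgl.heterograph(
#         {("user", "ui", "item"): ([0, 0, 1], [0, 1, 2])},
#         num_nodes_dict={"user": 2, "item": 3},
#     )
#     train, evl, tst = split_data(toy_hg, "ui")
#     # Each split is an (N, 3) array of (src, dst, label) rows; the sizes
#     # are roughly 60/20/20 with a 1:1 positive/negative mix overall.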


def process_amazon(root_path):
    # Dataset statistics: relation, #src nodes, #dst nodes, #edges, meta-path
    # User-Item      3584  2753  50903  UIUI
    # Item-View      2753  3857   5694  UIVI
    # Item-Brand     2753   334   2753  UIBI
    # Item-Category  2753    22   5508  UICI

    # Construct graph from raw data.
    # Load the raw Amazon relation files.
    data_path = os.path.join(root_path, "Amazon")
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            "Cannot find Amazon data in {}, please download the dataset first.".format(
                data_path
            )
        )

    # item_view
    item_view_src = []
    item_view_dst = []
    with open(os.path.join(data_path, "item_view.dat")) as fin:
        for line in fin:
            _line = line.strip().split(",")
            item, view = int(_line[0]), int(_line[1])
            item_view_src.append(item)
            item_view_dst.append(view)

    # user_item
    user_item_src = []
    user_item_dst = []
    with open(os.path.join(data_path, "user_item.dat")) as fin:
        for line in fin:
            _line = line.strip().split("\t")
            user, item, rate = int(_line[0]), int(_line[1]), int(_line[2])
            # Keep only positive interactions (rating > 3).
            if rate > 3:
                user_item_src.append(user)
                user_item_dst.append(item)

    # item_brand
    item_brand_src = []
    item_brand_dst = []
    with open(os.path.join(data_path, "item_brand.dat")) as fin:
        for line in fin:
            _line = line.strip().split(",")
            item, brand = int(_line[0]), int(_line[1])
            item_brand_src.append(item)
            item_brand_dst.append(brand)

    # item_category
    item_category_src = []
    item_category_dst = []
    with open(os.path.join(data_path, "item_category.dat")) as fin:
        for line in fin:
            _line = line.strip().split(",")
            item, category = int(_line[0]), int(_line[1])
            item_category_src.append(item)
            item_category_dst.append(category)

    # Build the heterogeneous graph with forward and reverse edge types.
    hg = dgl.heterograph(
        {
            ("item", "iv", "view"): (item_view_src, item_view_dst),
            ("view", "vi", "item"): (item_view_dst, item_view_src),
            ("user", "ui", "item"): (user_item_src, user_item_dst),
            ("item", "iu", "user"): (user_item_dst, user_item_src),
            ("item", "ib", "brand"): (item_brand_src, item_brand_dst),
            ("brand", "bi", "item"): (item_brand_dst, item_brand_src),
            ("item", "ic", "category"): (item_category_src, item_category_dst),
            ("category", "ci", "item"): (item_category_dst, item_category_src),
        }
    )

    print("Graph constructed.")

    # Split data into train/eval/test
    train_data, eval_data, test_data = split_data(hg, "ui")

    # Rebuild the graph with only the train-positive user-item edges, so
    # eval/test positives are removed from the training graph.
    train_pos = np.nonzero(train_data[:, 2])
    train_pos_idx = train_pos[0]
    user_item_src_processed = train_data[train_pos_idx, 0]
    user_item_dst_processed = train_data[train_pos_idx, 1]
    edges_dict = {
        ("item", "iv", "view"): (item_view_src, item_view_dst),
        ("view", "vi", "item"): (item_view_dst, item_view_src),
        ("user", "ui", "item"): (
            user_item_src_processed,
            user_item_dst_processed,
        ),
        ("item", "iu", "user"): (
            user_item_dst_processed,
            user_item_src_processed,
        ),
        ("item", "ib", "brand"): (item_brand_src, item_brand_dst),
        ("brand", "bi", "item"): (item_brand_dst, item_brand_src),
        ("item", "ic", "category"): (item_category_src, item_category_dst),
        ("category", "ci", "item"): (item_category_dst, item_category_src),
    }
    nodes_dict = {
        "user": hg.num_nodes("user"),
        "item": hg.num_nodes("item"),
        "view": hg.num_nodes("view"),
        "brand": hg.num_nodes("brand"),
        "category": hg.num_nodes("category"),
    }
    hg_processed = dgl.heterograph(
        data_dict=edges_dict, num_nodes_dict=nodes_dict
    )
    print("Graph processed.")

    # save the processed data
    with open(os.path.join(root_path, "amazon_hg.pkl"), "wb") as file:
        pkl.dump(hg_processed, file)
    with open(os.path.join(root_path, "amazon_train.pkl"), "wb") as file:
        pkl.dump(train_data, file)
    with open(os.path.join(root_path, "amazon_test.pkl"), "wb") as file:
        pkl.dump(test_data, file)
    with open(os.path.join(root_path, "amazon_eval.pkl"), "wb") as file:
        pkl.dump(eval_data, file)

    return hg_processed, train_data, eval_data, test_data
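
# Usage sketch (hedged; assumes the raw files are under ./data/Amazon):
#
#     hg, train, evl, tst = process_amazon("./data")
#     # hg keeps only the train-positive "ui"/"iu" edges; eval/test positives
#     # are held out. Each split is an (N, 3) array of (user, item, label).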


def process_movielens(root_path):
    # Dataset statistics: relation, #src nodes, #dst nodes, #edges, meta-path
    # User-Movie       943  1682  100000  UMUM
    # User-Age         943     8     943  UAUM
    # User-Occupation  943    21     943  UOUM
    # Movie-Genre     1682    18    2861  UMGM

    data_path = os.path.join(root_path, "Movielens")
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            "Cannot find MovieLens data in {}, please download the dataset first.".format(
                data_path
            )
        )

    # Construct graph from raw data.
    # movie_genre
    movie_genre_src = []
    movie_genre_dst = []
    with open(os.path.join(data_path, "movie_genre.dat")) as fin:
        for line in fin:
            _line = line.strip().split("\t")
            movie, genre = int(_line[0]), int(_line[1])
            movie_genre_src.append(movie)
            movie_genre_dst.append(genre)

    # user_movie
    user_movie_src = []
    user_movie_dst = []
    with open(os.path.join(data_path, "user_movie.dat")) as fin:
        for line in fin:
            _line = line.strip().split("\t")
            user, item, rate = int(_line[0]), int(_line[1]), int(_line[2])
            # Keep only positive interactions (rating > 3).
            if rate > 3:
                user_movie_src.append(user)
                user_movie_dst.append(item)

    # user_occupation
    user_occupation_src = []
    user_occupation_dst = []
    with open(os.path.join(data_path, "user_occupation.dat")) as fin:
        for line in fin:
            _line = line.strip().split("\t")
            user, occupation = int(_line[0]), int(_line[1])
            user_occupation_src.append(user)
            user_occupation_dst.append(occupation)

    # user_age
    user_age_src = []
    user_age_dst = []
    with open(os.path.join(data_path, "user_age.dat")) as fin:
        for line in fin:
            _line = line.strip().split("\t")
            user, age = int(_line[0]), int(_line[1])
            user_age_src.append(user)
            user_age_dst.append(age)

    # Build the heterogeneous graph with forward and reverse edge types.
    hg = dgl.heterograph(
        {
            ("movie", "mg", "genre"): (movie_genre_src, movie_genre_dst),
            ("genre", "gm", "movie"): (movie_genre_dst, movie_genre_src),
            ("user", "um", "movie"): (user_movie_src, user_movie_dst),
            ("movie", "mu", "user"): (user_movie_dst, user_movie_src),
            ("user", "uo", "occupation"): (
                user_occupation_src,
                user_occupation_dst,
            ),
            ("occupation", "ou", "user"): (
                user_occupation_dst,
                user_occupation_src,
            ),
            ("user", "ua", "age"): (user_age_src, user_age_dst),
            ("age", "au", "user"): (user_age_dst, user_age_src),
        }
    )

    print("Graph constructed.")

    # Split data into train/eval/test
    train_data, eval_data, test_data = split_data(hg, "um")

    # Rebuild the graph with only the train-positive user-movie edges, so
    # eval/test positives are removed from the training graph.
    train_pos = np.nonzero(train_data[:, 2])
    train_pos_idx = train_pos[0]
    user_movie_src_processed = train_data[train_pos_idx, 0]
    user_movie_dst_processed = train_data[train_pos_idx, 1]
    edges_dict = {
        ("movie", "mg", "genre"): (movie_genre_src, movie_genre_dst),
        ("genre", "gm", "movie"): (movie_genre_dst, movie_genre_src),
        ("user", "um", "movie"): (
            user_movie_src_processed,
            user_movie_dst_processed,
        ),
        ("movie", "mu", "user"): (
            user_movie_dst_processed,
            user_movie_src_processed,
        ),
        ("user", "uo", "occupation"): (
            user_occupation_src,
            user_occupation_dst,
        ),
        ("occupation", "ou", "user"): (
            user_occupation_dst,
            user_occupation_src,
        ),
        ("user", "ua", "age"): (user_age_src, user_age_dst),
        ("age", "au", "user"): (user_age_dst, user_age_src),
    }
    nodes_dict = {
        "user": hg.num_nodes("user"),
        "movie": hg.num_nodes("movie"),
        "genre": hg.num_nodes("genre"),
        "occupation": hg.num_nodes("occupation"),
        "age": hg.num_nodes("age"),
    }
    hg_processed = dgl.heterograph(
        data_dict=edges_dict, num_nodes_dict=nodes_dict
    )
    print("Graph processed.")

    # save the processed data
    with open(os.path.join(root_path, "movielens_hg.pkl"), "wb") as file:
        pkl.dump(hg_processed, file)
    with open(os.path.join(root_path, "movielens_train.pkl"), "wb") as file:
        pkl.dump(train_data, file)
    with open(os.path.join(root_path, "movielens_test.pkl"), "wb") as file:
        pkl.dump(test_data, file)
    with open(os.path.join(root_path, "movielens_eval.pkl"), "wb") as file:
        pkl.dump(eval_data, file)

    return hg_processed, train_data, eval_data, test_data


class MyDataset(Dataset):
    """Wraps an (N, 3) tensor of (user, item, label) triples."""

    def __init__(self, triple):
        self.triple = triple
        self.len = self.triple.shape[0]

    def __getitem__(self, index):
        return (
            self.triple[index, 0],
            self.triple[index, 1],
            self.triple[index, 2].float(),  # cast label to float for BCE-style losses
        )

    def __len__(self):
        return self.len
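
# A minimal item-access sketch (the triples tensor below is hypothetical):
#
#     triples = torch.tensor([[0, 5, 1], [2, 7, 0]])
#     ds = MyDataset(triples)
#     user, item, label = ds[0]  # tensor(0), tensor(5), tensor(1.)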


def load_data(dataset, batch_size=128, num_workers=10, root_path="./data"):
    if os.path.exists(os.path.join(root_path, dataset + "_train.pkl")):
        with open(os.path.join(root_path, dataset + "_hg.pkl"), "rb") as f:
            hg = pkl.load(f)
        with open(os.path.join(root_path, dataset + "_train.pkl"), "rb") as f:
            train_set = pkl.load(f)
        with open(os.path.join(root_path, dataset + "_test.pkl"), "rb") as f:
            test_set = pkl.load(f)
        with open(os.path.join(root_path, dataset + "_eval.pkl"), "rb") as f:
            eval_set = pkl.load(f)
    else:
        if dataset == "movielens":
            hg, train_set, eval_set, test_set = process_movielens(root_path)
        elif dataset == "amazon":
            hg, train_set, eval_set, test_set = process_amazon(root_path)
        else:
            print("Available datasets: movielens, amazon.")
KounianhuaDu's avatar
KounianhuaDu committed
345
346
            raise NotImplementedError

    if dataset == "movielens":
        # Per-node-type meta-paths, given as lists of edge-type sequences.
        meta_paths = {
            "user": [["um", "mu"]],
            "movie": [["mu", "um"], ["mg", "gm"]],
        }
        user_key = "user"
        item_key = "movie"
    elif dataset == "amazon":
        meta_paths = {
            "user": [["ui", "iu"]],
            "item": [["iu", "ui"], ["ic", "ci"], ["ib", "bi"], ["iv", "vi"]],
        }
        user_key = "user"
        item_key = "item"
    else:
        print("Available datasets: movielens, amazon.")
KounianhuaDu's avatar
KounianhuaDu committed
363
        raise NotImplementedError

    # Cast the (src, dst, label) arrays to long tensors.
    train_set = torch.Tensor(train_set).long()
    eval_set = torch.Tensor(eval_set).long()
    test_set = torch.Tensor(test_set).long()

    train_set = MyDataset(train_set)
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    eval_set = MyDataset(eval_set)
    eval_loader = DataLoader(
        dataset=eval_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_set = MyDataset(test_set)
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    return (
        hg,
        train_loader,
        eval_loader,
        test_loader,
        meta_paths,
        user_key,
        item_key,
    )
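

# A minimal usage sketch (hedged: it assumes the MovieLens files are already
# under ./data/Movielens, and uses num_workers=0 to avoid multiprocessing).
if __name__ == "__main__":
    (
        hg,
        train_loader,
        eval_loader,
        test_loader,
        meta_paths,
        user_key,
        item_key,
    ) = load_data("movielens", batch_size=128, num_workers=0)
    print(hg)
    users, items, labels = next(iter(train_loader))
    # Each tensor has shape (batch_size,); labels are 0./1. floats.
    print(users.shape, items.shape, labels.shape)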