import os import pickle as pkl import random import dgl import numpy as np import torch from torch.utils.data import DataLoader, Dataset # Split data into train/eval/test def split_data(hg, etype_name): src, dst = hg.edges(etype=etype_name) user_item_src = src.numpy().tolist() user_item_dst = dst.numpy().tolist() num_link = len(user_item_src) pos_label = [1] * num_link pos_data = list(zip(user_item_src, user_item_dst, pos_label)) ui_adj = np.array(hg.adj_external(etype=etype_name).to_dense()) full_idx = np.where(ui_adj == 0) sample = random.sample(range(0, len(full_idx[0])), num_link) neg_label = [0] * num_link neg_data = list(zip(full_idx[0][sample], full_idx[1][sample], neg_label)) full_data = pos_data + neg_data random.shuffle(full_data) train_size = int(len(full_data) * 0.6) eval_size = int(len(full_data) * 0.2) test_size = len(full_data) - train_size - eval_size train_data = full_data[:train_size] eval_data = full_data[train_size : train_size + eval_size] test_data = full_data[ train_size + eval_size : train_size + eval_size + test_size ] train_data = np.array(train_data) eval_data = np.array(eval_data) test_data = np.array(test_data) return train_data, eval_data, test_data def process_amazon(root_path): # User-Item 3584 2753 50903 UIUI # Item-View 2753 3857 5694 UIVI # Item-Brand 2753 334 2753 UIBI # Item-Category 2753 22 5508 UICI # Construct graph from raw data. # load data of amazon data_path = os.path.join(root_path, "Amazon") if not (os.path.exists(data_path)): print( "Can not find amazon in {}, please download the dataset first.".format( data_path ) ) # item_view item_view_src = [] item_view_dst = [] with open(os.path.join(data_path, "item_view.dat")) as fin: for line in fin.readlines(): _line = line.strip().split(",") item, view = int(_line[0]), int(_line[1]) item_view_src.append(item) item_view_dst.append(view) # user_item user_item_src = [] user_item_dst = [] with open(os.path.join(data_path, "user_item.dat")) as fin: for line in fin.readlines(): _line = line.strip().split("\t") user, item, rate = int(_line[0]), int(_line[1]), int(_line[2]) if rate > 3: user_item_src.append(user) user_item_dst.append(item) # item_brand item_brand_src = [] item_brand_dst = [] with open(os.path.join(data_path, "item_brand.dat")) as fin: for line in fin.readlines(): _line = line.strip().split(",") item, brand = int(_line[0]), int(_line[1]) item_brand_src.append(item) item_brand_dst.append(brand) # item_category item_category_src = [] item_category_dst = [] with open(os.path.join(data_path, "item_category.dat")) as fin: for line in fin.readlines(): _line = line.strip().split(",") item, category = int(_line[0]), int(_line[1]) item_category_src.append(item) item_category_dst.append(category) # build graph hg = dgl.heterograph( { ("item", "iv", "view"): (item_view_src, item_view_dst), ("view", "vi", "item"): (item_view_dst, item_view_src), ("user", "ui", "item"): (user_item_src, user_item_dst), ("item", "iu", "user"): (user_item_dst, user_item_src), ("item", "ib", "brand"): (item_brand_src, item_brand_dst), ("brand", "bi", "item"): (item_brand_dst, item_brand_src), ("item", "ic", "category"): (item_category_src, item_category_dst), ("category", "ci", "item"): (item_category_dst, item_category_src), } ) print("Graph constructed.") # Split data into train/eval/test train_data, eval_data, test_data = split_data(hg, "ui") # delete the positive edges in eval/test data in the original graph train_pos = np.nonzero(train_data[:, 2]) train_pos_idx = train_pos[0] user_item_src_processed = train_data[train_pos_idx, 0] user_item_dst_processed = train_data[train_pos_idx, 1] edges_dict = { ("item", "iv", "view"): (item_view_src, item_view_dst), ("view", "vi", "item"): (item_view_dst, item_view_src), ("user", "ui", "item"): ( user_item_src_processed, user_item_dst_processed, ), ("item", "iu", "user"): ( user_item_dst_processed, user_item_src_processed, ), ("item", "ib", "brand"): (item_brand_src, item_brand_dst), ("brand", "bi", "item"): (item_brand_dst, item_brand_src), ("item", "ic", "category"): (item_category_src, item_category_dst), ("category", "ci", "item"): (item_category_dst, item_category_src), } nodes_dict = { "user": hg.num_nodes("user"), "item": hg.num_nodes("item"), "view": hg.num_nodes("view"), "brand": hg.num_nodes("brand"), "category": hg.num_nodes("category"), } hg_processed = dgl.heterograph( data_dict=edges_dict, num_nodes_dict=nodes_dict ) print("Graph processed.") # save the processed data with open(os.path.join(root_path, "amazon_hg.pkl"), "wb") as file: pkl.dump(hg_processed, file) with open(os.path.join(root_path, "amazon_train.pkl"), "wb") as file: pkl.dump(train_data, file) with open(os.path.join(root_path, "amazon_test.pkl"), "wb") as file: pkl.dump(test_data, file) with open(os.path.join(root_path, "amazon_eval.pkl"), "wb") as file: pkl.dump(eval_data, file) return hg_processed, train_data, eval_data, test_data def process_movielens(root_path): # User-Movie 943 1682 100000 UMUM # User-Age 943 8 943 UAUM # User-Occupation 943 21 943 UOUM # Movie-Genre 1682 18 2861 UMGM data_path = os.path.join(root_path, "Movielens") if not (os.path.exists(data_path)): print( "Can not find movielens in {}, please download the dataset first.".format( data_path ) ) # Construct graph from raw data. # movie_genre movie_genre_src = [] movie_genre_dst = [] with open(os.path.join(data_path, "movie_genre.dat")) as fin: for line in fin.readlines(): _line = line.strip().split("\t") movie, genre = int(_line[0]), int(_line[1]) movie_genre_src.append(movie) movie_genre_dst.append(genre) # user_movie user_movie_src = [] user_movie_dst = [] with open(os.path.join(data_path, "user_movie.dat")) as fin: for line in fin.readlines(): _line = line.strip().split("\t") user, item, rate = int(_line[0]), int(_line[1]), int(_line[2]) if rate > 3: user_movie_src.append(user) user_movie_dst.append(item) # user_occupation user_occupation_src = [] user_occupation_dst = [] with open(os.path.join(data_path, "user_occupation.dat")) as fin: for line in fin.readlines(): _line = line.strip().split("\t") user, occupation = int(_line[0]), int(_line[1]) user_occupation_src.append(user) user_occupation_dst.append(occupation) # user_age user_age_src = [] user_age_dst = [] with open(os.path.join(data_path, "user_age.dat")) as fin: for line in fin.readlines(): _line = line.strip().split("\t") user, age = int(_line[0]), int(_line[1]) user_age_src.append(user) user_age_dst.append(age) # build graph hg = dgl.heterograph( { ("movie", "mg", "genre"): (movie_genre_src, movie_genre_dst), ("genre", "gm", "movie"): (movie_genre_dst, movie_genre_src), ("user", "um", "movie"): (user_movie_src, user_movie_dst), ("movie", "mu", "user"): (user_movie_dst, user_movie_src), ("user", "uo", "occupation"): ( user_occupation_src, user_occupation_dst, ), ("occupation", "ou", "user"): ( user_occupation_dst, user_occupation_src, ), ("user", "ua", "age"): (user_age_src, user_age_dst), ("age", "au", "user"): (user_age_dst, user_age_src), } ) print("Graph constructed.") # Split data into train/eval/test train_data, eval_data, test_data = split_data(hg, "um") # delete the positive edges in eval/test data in the original graph train_pos = np.nonzero(train_data[:, 2]) train_pos_idx = train_pos[0] user_movie_src_processed = train_data[train_pos_idx, 0] user_movie_dst_processed = train_data[train_pos_idx, 1] edges_dict = { ("movie", "mg", "genre"): (movie_genre_src, movie_genre_dst), ("genre", "gm", "movie"): (movie_genre_dst, movie_genre_src), ("user", "um", "movie"): ( user_movie_src_processed, user_movie_dst_processed, ), ("movie", "mu", "user"): ( user_movie_dst_processed, user_movie_src_processed, ), ("user", "uo", "occupation"): ( user_occupation_src, user_occupation_dst, ), ("occupation", "ou", "user"): ( user_occupation_dst, user_occupation_src, ), ("user", "ua", "age"): (user_age_src, user_age_dst), ("age", "au", "user"): (user_age_dst, user_age_src), } nodes_dict = { "user": hg.num_nodes("user"), "movie": hg.num_nodes("movie"), "genre": hg.num_nodes("genre"), "occupation": hg.num_nodes("occupation"), "age": hg.num_nodes("age"), } hg_processed = dgl.heterograph( data_dict=edges_dict, num_nodes_dict=nodes_dict ) print("Graph processed.") # save the processed data with open(os.path.join(root_path, "movielens_hg.pkl"), "wb") as file: pkl.dump(hg_processed, file) with open(os.path.join(root_path, "movielens_train.pkl"), "wb") as file: pkl.dump(train_data, file) with open(os.path.join(root_path, "movielens_test.pkl"), "wb") as file: pkl.dump(test_data, file) with open(os.path.join(root_path, "movielens_eval.pkl"), "wb") as file: pkl.dump(eval_data, file) return hg_processed, train_data, eval_data, test_data class MyDataset(Dataset): def __init__(self, triple): self.triple = triple self.len = self.triple.shape[0] def __getitem__(self, index): return ( self.triple[index, 0], self.triple[index, 1], self.triple[index, 2].float(), ) def __len__(self): return self.len def load_data(dataset, batch_size=128, num_workers=10, root_path="./data"): if os.path.exists(os.path.join(root_path, dataset + "_train.pkl")): g_file = open(os.path.join(root_path, dataset + "_hg.pkl"), "rb") hg = pkl.load(g_file) g_file.close() train_set_file = open( os.path.join(root_path, dataset + "_train.pkl"), "rb" ) train_set = pkl.load(train_set_file) train_set_file.close() test_set_file = open( os.path.join(root_path, dataset + "_test.pkl"), "rb" ) test_set = pkl.load(test_set_file) test_set_file.close() eval_set_file = open( os.path.join(root_path, dataset + "_eval.pkl"), "rb" ) eval_set = pkl.load(eval_set_file) eval_set_file.close() else: if dataset == "movielens": hg, train_set, eval_set, test_set = process_movielens(root_path) elif dataset == "amazon": hg, train_set, eval_set, test_set = process_amazon(root_path) else: print("Available datasets: movielens, amazon.") raise NotImplementedError if dataset == "movielens": meta_paths = { "user": [["um", "mu"]], "movie": [["mu", "um"], ["mg", "gm"]], } user_key = "user" item_key = "movie" elif dataset == "amazon": meta_paths = { "user": [["ui", "iu"]], "item": [["iu", "ui"], ["ic", "ci"], ["ib", "bi"], ["iv", "vi"]], } user_key = "user" item_key = "item" else: print("Available datasets: movielens, amazon.") raise NotImplementedError train_set = torch.Tensor(train_set).long() eval_set = torch.Tensor(eval_set).long() test_set = torch.Tensor(test_set).long() train_set = MyDataset(train_set) train_loader = DataLoader( dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, ) eval_set = MyDataset(eval_set) eval_loader = DataLoader( dataset=eval_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, ) test_set = MyDataset(test_set) test_loader = DataLoader( dataset=test_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, ) return ( hg, train_loader, eval_loader, test_loader, meta_paths, user_key, item_key, )