import os
import numpy as np
import scipy.sparse as sp
import pickle
import torch
from torch.utils.data import DataLoader
from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_archive
import random
import time
import dgl

from utils import shuffle_walks


def ReadTxtNet(file_path="", undirected=True):
    """ Read the txt network file.

    Each non-empty line holds two or three whitespace-separated integers:
    ``src dst`` or ``src dst weight``. A missing weight defaults to 1.
    Duplicate edges are ignored; the first weight seen for an edge is kept.
    The special names 'youtube' and 'blog' download and extract the
    corresponding DGL DeepWalk dataset first and then read its net file.

    Parameters
    ----------
    file_path str : path of network file, or 'youtube' / 'blog'
    undirected bool : whether the edges are undirected (each edge is also
        added in the reverse direction)

    Return
    ------
    net dict : net[u][v] = weight, keyed by embedding indices
    node2id dict : a dict mapping the nodes to their embedding indices
    id2node dict : a dict mapping nodes embedding indices to the nodes
    sm scipy.sparse.coo_matrix : float32 (num_nodes, num_nodes) weighted
        adjacency matrix over embedding indices
    """
    if file_path in ('youtube', 'blog'):
        name = file_path
        data_dir = get_download_dir()
        zip_file_path = '{}/{}.zip'.format(data_dir, name)
        # This is a URL path, so build it with '/' rather than os.path.join.
        download(_get_dgl_url('dataset/DeepWalk/{}.zip'.format(name)),
                 path=zip_file_path)
        extract_archive(zip_file_path, '{}/{}'.format(data_dir, name))
        file_path = "{}/{}/{}-net.txt".format(data_dir, name, name)

    node2id = {}
    id2node = {}
    src = []
    dst = []
    weight = []
    net = {}

    def add_edge(u, v, w):
        # Record the directed edge u -> v once; later duplicates are ignored.
        nbrs = net.setdefault(u, {})
        if v not in nbrs:
            nbrs[v] = w
            src.append(u)
            dst.append(v)
            weight.append(w)

    with open(file_path, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing empty line.
                continue
            tup = list(map(int, fields))
            assert len(tup) in [2, 3], "The format of network file is unrecognizable."
            n1, n2 = tup[0], tup[1]
            w = tup[2] if len(tup) == 3 else 1

            # Assign consecutive embedding indices in order of first appearance.
            for n in (n1, n2):
                if n not in node2id:
                    cid = len(node2id)
                    node2id[n] = cid
                    id2node[cid] = n

            n1 = node2id[n1]
            n2 = node2id[n2]
            add_edge(n1, n2, w)
            if undirected:
                add_edge(n2, n1, w)

    # Count from node2id: with directed input, sink nodes never become keys of net.
    num_nodes = len(node2id)
    print("node num: %d" % num_nodes)
    print("edge num: %d" % len(src))
    if undirected:
        # Every node has an out-edge here, so net's keys must be exactly 0..n-1.
        assert max(net.keys()) == len(net) - 1, "error reading net, quit"

    # An explicit shape keeps the matrix square even if the highest index only
    # appears as a source or only as a destination (directed input).
    sm = sp.coo_matrix(
        (np.array(weight), (src, dst)),
        shape=(num_nodes, num_nodes),
        dtype=np.float32)

    return net, node2id, id2node, sm


def net2graph(net_sm):
    """ Transform the network to DGL graph

    Parameters
    ----------
    net_sm scipy.sparse matrix : adjacency matrix (as returned by ReadTxtNet)

    Return
    ------
    G DGLGraph : graph by DGL
    """
    start = time.time()
    G = dgl.DGLGraph(net_sm)
    end = time.time()
    t = end - start
    print("Building DGLGraph in %.2fs" % t)
    return G


def make_undirected(G):
    """ Add the reverse of every existing edge to G in place and return G.

    NOTE(review): relies on the legacy DGLGraph API (readonly / add_edges);
    confirm against the DGL version this project pins.
    """
    G.readonly(False)
    src, dst = G.edges()
    G.add_edges(dst, src)
    return G


def find_connected_nodes(G):
    """ Return the ids (Python ints) of all nodes with at least one out-edge. """
    nodes = []
    for n in G.nodes():
        if G.out_degree(n) > 0:
            nodes.append(n.item())
    return nodes


class DeepwalkDataset:
    def __init__(self,
                 net_file,
                 map_file,
                 walk_length,
                 window_size,
                 num_walks,
                 batch_size,
                 negative=5,
                 gpus=(0,),
                 fast_neg=True,
                 ogbl_name="",
                 load_from_ogbl=False,
                 ):
        """ This class has the following functions:
        1. Transform the txt network file into DGL graph;
        2. Generate random walk sequences for the trainer;
        3. Provide the negative table if the user hopes to sample negative
        nodes according to nodes' degrees;

        Parameter
        ---------
        net_file str : path of the txt network file
        map_file str : path where the node -> embedding-index mapping is pickled
        walk_length int : number of nodes in a sequence
        window_size int : context window size
        num_walks int : number of walks for each node
        batch_size int : number of node sequences in each batch
        negative int : negative samples for each positive node pair
        gpus sequence : GPU ids; its length sets how many seed partitions are made
        fast_neg bool : whether do negative sampling inside a batch
        ogbl_name str : ogbl dataset name, used when load_from_ogbl is True
        load_from_ogbl bool : load the graph from ogb.linkproppred instead of net_file
        """
        self.walk_length = walk_length
        self.window_size = window_size
        self.num_walks = num_walks
        self.batch_size = batch_size
        self.negative = negative
        self.num_procs = len(gpus)
        self.fast_neg = fast_neg

        if load_from_ogbl:
            assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)."
            from load_dataset import load_from_ogbl_with_name
            self.G = load_from_ogbl_with_name(ogbl_name)
            self.G = make_undirected(self.G)
        else:
            self.net, self.node2id, self.id2node, self.sm = ReadTxtNet(net_file)
            self.save_mapping(map_file)
            self.G = net2graph(self.sm)

        self.num_nodes = self.G.number_of_nodes()

        # random walk seeds: every node with an out-edge, repeated num_walks times,
        # shuffled, then split into one chunk per process
        start = time.time()
        self.valid_seeds = find_connected_nodes(self.G)
        if len(self.valid_seeds) != self.num_nodes:
            print("WARNING: The node ids are not serial. Some nodes are invalid.")

        seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks)
        chunk_size = int(np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs))
        self.seeds = torch.split(shuffle_walks(seeds), chunk_size, 0)
        end = time.time()
        t = end - start
        print("%d seeds in %.2fs" % (len(seeds), t))

        # negative table for true negative sampling: each node appears
        # proportionally to out_degree ** 0.75 (about 1e8 entries in total)
        if not fast_neg:
            node_degree = np.array([self.G.out_degree(x) for x in self.valid_seeds])
            node_degree = np.power(node_degree, 0.75)
            node_degree /= np.sum(node_degree)
            # np.int / np.long were removed from NumPy; use an explicit width.
            node_degree = np.array(node_degree * 1e8, dtype=np.int64)

            # np.repeat builds the same table as repeated list concatenation,
            # without materialising ~1e8 Python ints first.
            self.neg_table = np.repeat(
                np.asarray(self.valid_seeds, dtype=np.int64), node_degree)
            self.neg_table_size = len(self.neg_table)
            del node_degree

    def create_sampler(self, i):
        """ create random walk sampler over the i-th seed partition """
        return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)

    def save_mapping(self, map_file):
        """ save the mapping dict that maps node IDs to embedding indices """
        with open(map_file, "wb") as f:
            pickle.dump(self.node2id, f)


class DeepwalkSampler(object):
    def __init__(self, G, seeds, walk_length):
        """ random walk sampler

        Parameter
        ---------
        G dgl.Graph : the input graph
        seeds torch.LongTensor : starting nodes
        walk_length int : walk length
        """
        self.G = G
        self.seeds = seeds
        self.walk_length = walk_length

    def sample(self, seeds):
        """ Run one random walk of walk_length nodes from each node in `seeds`.

        NOTE(review): uses the `seeds` argument, not self.seeds; presumably this
        is passed as a DataLoader collate_fn over self.seeds — confirm with caller.
        """
        walks = dgl.contrib.sampling.random_walk(self.G, seeds,
                                                 1, self.walk_length - 1)
        return walks