import argparse

from dgl.data import CitationGraphDataset
from ogb.nodeproppred import *
from ogb.linkproppred import *


# Dataset names routed to dgl.data.CitationGraphDataset in load_graph().
_CITE_GRAPHS = ('cora', 'citeseer', 'pubmed')


def _make_eval_set(nodes, y, train_sel, val_sel):
    """Return ``[(nodes_train, y_train), (nodes_val, y_val)]``.

    ``train_sel`` / ``val_sel`` are applied identically to ``nodes`` and
    ``y``. They are node masks in the citation branch of ``load_graph`` and
    split index tensors in the OGB branch.
    """
    return [(nodes[train_sel], y[train_sel]), (nodes[val_sel], y[val_sel])]


def load_graph(name):
    """Load a node-classification graph together with its evaluation splits.

    Parameters
    ----------
    name : str
        Either one of 'cora', 'citeseer', 'pubmed' (loaded through DGL's
        ``CitationGraphDataset``) or a dataset name starting with 'ogbn'
        (loaded through OGB's ``DglNodePropPredDataset``).

    Returns
    -------
    graph
        The graph object returned by the dataset (indexed at 0).
    eval_set : list
        ``[(nodes_train, y_train), (nodes_val, y_val)]``: node ids and their
        labels for the training split and the evaluation split.

    Raises
    ------
    ValueError
        If ``name`` is neither a supported citation graph nor an 'ogbn' name.
    """
    if name in _CITE_GRAPHS:
        graph = CitationGraphDataset(name)[0]
        y = graph.ndata['label']
        train_sel = graph.ndata['train_mask']
        # NOTE(review): the evaluation half of eval_set is built from the
        # *test* mask here, whereas the OGB branch uses the 'valid' split.
        # Kept unchanged so previously reported numbers stay comparable --
        # confirm whether 'val_mask' was intended.
        val_sel = graph.ndata['test_mask']
    elif name.startswith('ogbn'):
        dataset = DglNodePropPredDataset(name)
        graph, y = dataset[0]
        split_idx = dataset.get_idx_split()
        train_sel = split_idx['train']
        val_sel = split_idx['valid']
    else:
        raise ValueError(
            "Dataset name error! Unknown dataset {!r}; expected one of {} "
            "or a name starting with 'ogbn'.".format(name, ', '.join(_CITE_GRAPHS)))

    return graph, _make_eval_set(graph.nodes(), y, train_sel, val_sel)


def parse_arguments(argv=None):
    """Parse command-line options for the node2vec example.

    Parameters
    ----------
    argv : list of str, optional
        Argument list to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, so existing zero-argument callers are unaffected.

    Returns
    -------
    argparse.Namespace
        Parsed options; every flag has a default.
    """
    parser = argparse.ArgumentParser(description='Node2vec')
    parser.add_argument('--dataset', type=str, default='cora',
                        help="'cora', 'citeseer', 'pubmed', or an 'ogbn-*' dataset name")
    parser.add_argument('--task', type=str, default='train',
                        help="'train' to train the node2vec model, "
                             "'time' to test the speed of random walk")
    parser.add_argument('--runs', type=int, default=10,
                        help='number of runs')
    parser.add_argument('--device', type=str, default='cpu',
                        help="device string, e.g. 'cpu'")
    parser.add_argument('--embedding_dim', type=int, default=128,
                        help='embedding dimension')
    parser.add_argument('--walk_length', type=int, default=50,
                        help='length of each random walk')
    parser.add_argument('--p', type=float, default=0.25,
                        help='node2vec return hyperparameter p')
    parser.add_argument('--q', type=float, default=4.0,
                        help='node2vec in-out hyperparameter q')
    parser.add_argument('--num_walks', type=int, default=10,
                        help='number of random walks')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size')
    return parser.parse_args(argv)