import argparse

from ogb.linkproppred import *
from ogb.nodeproppred import *
from dgl.data import CitationGraphDataset


def _split_eval_set(nodes, y, train_idx, val_idx):
    """Build ``[(nodes_train, y_train), (nodes_val, y_val)]``.

    ``train_idx`` and ``val_idx`` are applied to ``nodes`` and ``y`` by plain
    indexing; callers below pass either per-node masks (citation graphs) or
    the index tensors returned by OGB's ``get_idx_split``.
    """
    return [
        (nodes[train_idx], y[train_idx]),
        (nodes[val_idx], y[val_idx]),
    ]


def load_graph(name):
    """Load a node-classification graph together with its evaluation splits.

    Args:
        name: ``"cora"``, ``"citeseer"`` or ``"pubmed"`` (loaded through DGL's
            ``CitationGraphDataset``), or any name starting with ``"ogbn"``
            (loaded through OGB's ``DglNodePropPredDataset``).

    Returns:
        A ``(graph, eval_set)`` tuple where ``eval_set`` is
        ``[(nodes_train, y_train), (nodes_val, y_val)]``.

    Raises:
        ValueError: If ``name`` belongs to neither supported dataset family.
    """
    cite_graphs = ["cora", "citeseer", "pubmed"]

    if name in cite_graphs:
        dataset = CitationGraphDataset(name)
        graph = dataset[0]
        y = graph.ndata["label"]
        train_mask = graph.ndata["train_mask"]
        # NOTE(review): the held-out split for citation graphs is taken from
        # the *test* mask, not "val_mask" (the OGB branch uses "valid").
        # Kept unchanged so reported numbers stay comparable -- confirm intent.
        val_mask = graph.ndata["test_mask"]
        eval_set = _split_eval_set(graph.nodes(), y, train_mask, val_mask)
    elif name.startswith("ogbn"):
        dataset = DglNodePropPredDataset(name)
        graph, y = dataset[0]
        split_nodes = dataset.get_idx_split()
        eval_set = _split_eval_set(
            graph.nodes(), y, split_nodes["train"], split_nodes["valid"]
        )
    else:
        raise ValueError(
            f"Dataset name error! Unknown dataset {name!r}; expected one of "
            f"{cite_graphs} or a name starting with 'ogbn'."
        )

    return graph, eval_set


def parse_arguments():
    """Parse the node2vec command-line arguments.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="Node2vec")
    parser.add_argument(
        "--dataset", type=str, default="cora",
        help="dataset name: cora / citeseer / pubmed or an ogbn-* dataset",
    )
    # 'train' for training node2vec model, 'time' for testing speed of random walk
    parser.add_argument(
        "--task", type=str, default="train",
        help="'train' to train the node2vec model, 'time' to time random walks",
    )
    parser.add_argument("--runs", type=int, default=10,
                        help="number of runs")
    parser.add_argument("--device", type=str, default="cpu",
                        help="device string, e.g. 'cpu'")
    parser.add_argument("--embedding_dim", type=int, default=128,
                        help="embedding dimension")
    parser.add_argument("--walk_length", type=int, default=50,
                        help="random-walk length")
    parser.add_argument("--p", type=float, default=0.25,
                        help="node2vec return parameter p")
    parser.add_argument("--q", type=float, default=4.0,
                        help="node2vec in-out parameter q")
    parser.add_argument("--num_walks", type=int, default=10,
                        help="number of walks")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="batch size")
    args = parser.parse_args()
    return args