import argparse import torch import dgl from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset from dgl.nn import LabelPropagation def main(): # check cuda device = ( f"cuda:{args.gpu}" if torch.cuda.is_available() and args.gpu >= 0 else "cpu" ) # load data if args.dataset == "Cora": dataset = CoraGraphDataset() elif args.dataset == "Citeseer": dataset = CiteseerGraphDataset() elif args.dataset == "Pubmed": dataset = PubmedGraphDataset() else: raise ValueError("Dataset {} is invalid.".format(args.dataset)) g = dataset[0] g = dgl.add_self_loop(g) labels = g.ndata.pop("label").to(device).long() # load masks for train / test, valid is not used. train_mask = g.ndata.pop("train_mask") test_mask = g.ndata.pop("test_mask") train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device) test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device) g = g.to(device) # label propagation lp = LabelPropagation(args.num_layers, args.alpha) logits = lp(g, labels, mask=train_idx) test_acc = torch.sum( logits[test_idx].argmax(dim=1) == labels[test_idx] ).item() / len(test_idx) print("Test Acc {:.4f}".format(test_acc)) if __name__ == "__main__": """ Label Propagation Hyperparameters """ parser = argparse.ArgumentParser(description="LP") parser.add_argument("--gpu", type=int, default=0) parser.add_argument("--dataset", type=str, default="Cora") parser.add_argument("--num-layers", type=int, default=10) parser.add_argument("--alpha", type=float, default=0.5) args = parser.parse_args() print(args) main()