# main.py: node classification with DGL's LabelPropagation on Cora / Citeseer / Pubmed.
import argparse
2

3
import torch
4

5
import dgl
6
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
7
from dgl.nn import LabelPropagation
8
9
10
11


def main():
    """Run label propagation on a citation graph and print the test accuracy.

    Configuration is read from the module-level ``args`` namespace built by the
    ``__main__`` block (attributes ``gpu``, ``dataset``, ``num_layers``,
    ``alpha``).

    Raises:
        ValueError: if ``args.dataset`` is not "Cora", "Citeseer" or "Pubmed".
    """
    # check cuda: use the requested GPU only if CUDA exists and gpu >= 0
    device = (
        f"cuda:{args.gpu}"
        if torch.cuda.is_available() and args.gpu >= 0
        else "cpu"
    )

    # load data: dataset classes are instantiated lazily, only the chosen one
    dataset_classes = {
        "Cora": CoraGraphDataset,
        "Citeseer": CiteseerGraphDataset,
        "Pubmed": PubmedGraphDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))
    dataset = dataset_classes[args.dataset]()

    g = dataset[0]
    # Add self-loops (presumably so each node also aggregates its own value
    # and no node has zero in-degree during propagation).
    g = dgl.add_self_loop(g)

    labels = g.ndata.pop("label").to(device).long()

    # load masks for train / test, valid is not used.
    train_mask = g.ndata.pop("train_mask")
    test_mask = g.ndata.pop("test_mask")

    # as_tuple=True always yields a 1-D index tensor; nonzero(...).squeeze()
    # would collapse to a 0-d tensor when exactly one node is selected, and
    # len()/indexing on it would then misbehave.
    train_idx = torch.nonzero(train_mask, as_tuple=True)[0].to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=True)[0].to(device)

    g = g.to(device)

    # label propagation; per DGL's API, `mask` selects the nodes whose labels
    # are treated as known (here: the training nodes).
    lp = LabelPropagation(args.num_layers, args.alpha)
    logits = lp(g, labels, mask=train_idx)

    preds = logits[test_idx].argmax(dim=1)
    test_acc = torch.sum(preds == labels[test_idx]).item() / test_idx.numel()

    print("Test Acc {:.4f}".format(test_acc))


52
if __name__ == "__main__":
    # Label Propagation hyperparameters, parsed from the command line.
    parser = argparse.ArgumentParser(description="LP")
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="CUDA device id; a negative value forces CPU",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="Cora",
        help="one of Cora, Citeseer, Pubmed",
    )
    parser.add_argument(
        "--num-layers",
        type=int,
        default=10,
        help="number of propagation iterations passed to LabelPropagation",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="alpha coefficient passed to LabelPropagation",
    )

    # Bound at module level on purpose: main() reads the global `args`.
    args = parser.parse_args()
    print(args)

    main()